From 044adf8338267b75e2a3d27c3278fd06177cf068 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 8 Sep 2022 11:51:14 +0800 Subject: [PATCH 01/21] [fx] add some comment and docstrings. --- colossalai/fx/profiler/profiler.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 8f9fb92e0ae4..0547aa3072d7 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -33,19 +33,32 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ + # This subgraph traces aten level ops inside one node. + subgraph = Graph() + # flop_count serves as a global dictionary to store results. flop_count = { 'f': 0, 'l': 0, 'b': 0, } + + # TODO: remove this temp = { 'f': [], 'l': [], 'b': [], } + + # `stage` will mark the stage of autograd from outside scope. stage = 'f' + # FlopTensor not only get the flop statistics of a single node, + # it also build a full autograd graph for this node. + # This makes sure we can analyze the dependencies of memory, and + # decide which forward intermediate results should be kept until + # backward is executed. + # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): def __repr__(self): @@ -78,6 +91,8 @@ def wrap(x): return tree_map(wrap, out) + # `WEIRD_OPS` are tough to handle because they don't accept autograd + # on meta tensor. if target not in WEIRD_OPS: def wrap(x): @@ -89,6 +104,7 @@ def wrap(x): return FlopTensor( x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + # Basically, we need to detach the args and kwargs from the outer graph. args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) @@ -99,6 +115,8 @@ def wrap(x): else: out = target(*args, **kwargs) + # If the output is not a floating point `torch.Tensor` or it does not + # requires grad, then we should not run backward for this node. if is_autogradable(out) and out.requires_grad: stage = 'l' loss = out.sum() @@ -110,7 +128,8 @@ def wrap(x): fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0 fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 - bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0 + bwd_tmp = max(map(activation_size, temp['b'][:-1])) if len(temp['b'][:-1]) else 0 + fwd_out = activation_size(temp['b'][-1]) if len(temp['b']) else 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x @@ -134,11 +153,14 @@ def profile_function(target: 'Target') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # If there is an argument that this `call_function` is inplace, we should + # skip the autograd profiling. if kwargs.get('inplace', False): args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) out = func(*args, **kwargs) - return out, (0, 0), (0, 0, 0, 0) + return out, (out.numel(), out.numel()), (0, 0, 0, 0) out, flop_count, mem_stat = _profile(func, *args, **kwargs) return out, flop_count, mem_stat @@ -178,6 +200,9 @@ def profile_module(module: torch.nn.Module) -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # If there is an argument that this `call_module` is inplace, we should + # skip the autograd profiling. if getattr(module, 'inplace', False): args = tree_map(lambda x: x.to('meta'), args) kwargs = tree_map(lambda x: x.to('meta'), kwargs) From a42ab22345358092932484aa16ba2f25ecdba010 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 8 Sep 2022 13:02:43 +0800 Subject: [PATCH 02/21] [fx] add dataflow analysis for an autograd graph. --- colossalai/fx/profiler/dataflow.py | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 colossalai/fx/profiler/dataflow.py diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py new file mode 100644 index 000000000000..c002acf645a5 --- /dev/null +++ b/colossalai/fx/profiler/dataflow.py @@ -0,0 +1,47 @@ +from typing import Tuple +from torch.fx import Graph, Node +import torch + + +def is_forward(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'f' + + +def is_loss(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'l' + + +def is_backward(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'b' + + +def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: + """Analyze the autograd node dependencies and find out the memory usage. + Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward). + Nodes should have attribute `_out_tensors` indicating the output of each node. + + Args: + graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) + + Returns: + fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. + fwd_out (int): The output of the entire forward pass. + bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. + bwd_out (int): + """ + pass + + +def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: + """Apply liveness analysis to a list of nodes in topological order and calculate the peak memory. + + Args: + nodes (Tuple[Node, ...]): A list of nodes in topological order. + + Returns: + memory_peak (int): Peak memory encountered during the execution. + """ + pass From 0d55f26010fdf7e1a46a1192e6ec938f32c5b48c Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 9 Sep 2022 13:28:09 +0800 Subject: [PATCH 03/21] add intepretation for graph analysis. --- colossalai/fx/profiler/dataflow.py | 55 +++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index c002acf645a5..d731baa64034 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,38 +1,69 @@ from typing import Tuple from torch.fx import Graph, Node import torch +from .memory import activation_size def is_forward(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'f' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'f' def is_loss(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'l' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'l' def is_backward(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'b' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'b' def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: """Analyze the autograd node dependencies and find out the memory usage. - Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward). - Nodes should have attribute `_out_tensors` indicating the output of each node. - + Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. + Nodes should have attribute `out_tensors` indicating the output of each node. + ============================================================================ + p o <---- We need to keep track of grad out + |\________ | + ↓ ↘| + f --------> b + |\ \_____ ↑ + | \ ↘ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ ↑ + ↘ ↘| + f ----> b <---- Backward can be freed as soon as it is required no more. + ↘ ↗ + l + ============================================================================= Args: - graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) + graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. fwd_out (int): The output of the entire forward pass. bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. - bwd_out (int): + bwd_out (int): The output of the entire backward pass. """ - pass + # deps is used to track all the memory dependencies of the graph. + deps = {} + + fwd_tmp = 0 + fwd_out = 0 + bwd_tmp = 0 + bwd_out = 0 + + for n in graph.nodes: + n: Node + if is_forward(n): + if any(map(is_backward, n.users)): + fwd_tmp += activation_size(n.meta['out_tensors']) + if any(map(is_loss, n.users)): + fwd_out += activation_size(n.meta['out_tensors']) + elif is_backward(n): + if not len(n.users): + bwd_out += activation_size(n.meta['out_tensors']) def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: From 42c6e8c973fd2a6dca7b90d353fc04e84f8a1819 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 19:37:57 +0800 Subject: [PATCH 04/21] [fx] before doing save_tensor_hooks. --- colossalai/fx/passes/meta_info_prop.py | 4 +- colossalai/fx/profiler/dataflow.py | 84 +++++++++++++++----------- colossalai/fx/profiler/memory.py | 2 - colossalai/fx/profiler/profiler.py | 51 ++++++++++------ 4 files changed, 84 insertions(+), 57 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 1a1e149577c4..2c1a399dcf09 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -97,8 +97,8 @@ def extract_tensor_meta(obj): setattr(n, 'node_size', mem_stat[1]) setattr(n, 'fwd_flop', flop_count[0]) setattr(n, 'bwd_flop', flop_count[1]) - setattr(n, 'fwd_tmp', mem_stat[0]) - setattr(n, 'fwd_out', mem_stat[1]) + setattr(n, 'fwd_in', mem_stat[0]) + setattr(n, 'fwd_tmp', mem_stat[1]) setattr(n, 'bwd_tmp', mem_stat[2]) setattr(n, 'bwd_out', mem_stat[3]) n.meta['type'] = type(result) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index d731baa64034..14d635501bee 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,7 +1,6 @@ -from typing import Tuple +from typing import Dict, Tuple from torch.fx import Graph, Node -import torch -from .memory import activation_size +from .memory import INPLACE_ATEN, activation_size def is_forward(n: Node): @@ -14,6 +13,11 @@ def is_loss(n: Node): return n.meta['stage'] == 'l' +def is_placeholder(n: Node): + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'p' + + def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' return n.meta['stage'] == 'b' @@ -22,20 +26,20 @@ def is_backward(n: Node): def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. - Nodes should have attribute `out_tensors` indicating the output of each node. + Nodes should have attribute `out` indicating the output of each node. ============================================================================ - p o <---- We need to keep track of grad out - |\________ | - ↓ ↘| - f --------> b - |\ \_____ ↑ - | \ ↘ / - f f ----> b <---- Not every forward result needs to be saved for backward - | \____ ↑ - ↘ ↘| - f ----> b <---- Backward can be freed as soon as it is required no more. - ↘ ↗ - l + Placeholder ----> p o <---- We need to keep track of grad out + |\________ | + ↓ ↘| + f --------> b + |\ \_____ ↑ + | \ ↘ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ ↑ + ↘ ↘| + f ----> b <---- Backward can be freed as soon as it is required no more. + ↘ ↗ + l ============================================================================= Args: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. @@ -46,33 +50,43 @@ def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. bwd_out (int): The output of the entire backward pass. """ + + def _peak_memory(deps: Dict[Node, int]): + bwd_tmp = 0 + for k, v in deps.items(): + if v > 0: + bwd_tmp += activation_size(k.meta['out']) + return bwd_tmp + # deps is used to track all the memory dependencies of the graph. deps = {} + fwd_in = 0 fwd_tmp = 0 - fwd_out = 0 bwd_tmp = 0 bwd_out = 0 for n in graph.nodes: n: Node + if is_placeholder(n): + # a placeholder node who has any backward node users will have to be kept in memory until released + if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): + # but if its users are all inplace methods in forward pass, it should not have activations + fwd_in += activation_size(n.meta['out']) if is_forward(n): - if any(map(is_backward, n.users)): - fwd_tmp += activation_size(n.meta['out_tensors']) - if any(map(is_loss, n.users)): - fwd_out += activation_size(n.meta['out_tensors']) + # a forward node who has any backward node users will have to be kept in memory until released + if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): + # but if its users are all inplace methods in forward pass, it should not have activations + fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): - if not len(n.users): - bwd_out += activation_size(n.meta['out_tensors']) - - -def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: - """Apply liveness analysis to a list of nodes in topological order and calculate the peak memory. - - Args: - nodes (Tuple[Node, ...]): A list of nodes in topological order. - - Returns: - memory_peak (int): Peak memory encountered during the execution. - """ - pass + if len(n.users): + # liveness analysis is only used in backward + deps[n] = len(n.users) + bwd_tmp = max(bwd_tmp, _peak_memory(deps)) + for input_n in n.all_input_nodes: + if input_n in deps: + deps[input_n] -= 1 + else: + # basically a backward node without user is a `grad_out` node + bwd_out += activation_size(n.meta['out']) + return fwd_in, fwd_tmp, bwd_tmp, bwd_out diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index be51064220e0..c023d0d1eaa3 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -14,12 +14,10 @@ INPLACE_ATEN = [ aten.add_.Tensor, - aten.add.Tensor, aten.sub_.Tensor, aten.div_.Tensor, aten.div_.Scalar, aten.mul_.Tensor, - aten.mul.Tensor, aten.bernoulli_.float, # inplace reshaping diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 0547aa3072d7..9968ec0ebee3 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,9 +1,12 @@ +from lib2to3.pytree import Node +from operator import getitem from typing import Callable, Any, Dict, Tuple import torch from torch.fx import Graph from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS +from .dataflow import autograd_graph_analysis +from .memory import INPLACE_ATEN, WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping @@ -43,13 +46,6 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: 'b': 0, } - # TODO: remove this - temp = { - 'f': [], - 'l': [], - 'b': [], - } - # `stage` will mark the stage of autograd from outside scope. stage = 'f' @@ -61,6 +57,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): + _node: Node + def __repr__(self): if self.grad_fn: return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})" @@ -74,8 +72,12 @@ def unwrap(x): x = FlopTensor(x.to('meta')) return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - def to_meta(x): - return x.to('meta') if isinstance(x, torch.Tensor) else x + def get_node(x): + return None if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = subgraph.create_node('call_function', func, args_node, kwargs_node) args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -83,13 +85,18 @@ def to_meta(x): # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) - if func not in INPLACE_ATEN: - temp[stage].append(tree_map(to_meta, normalize_tuple(out))) + node.meta['out'] = normalize_tuple(out) + node.meta['stage'] = stage def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x - return tree_map(wrap, out) + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + return out # `WEIRD_OPS` are tough to handle because they don't accept autograd # on meta tensor. @@ -108,6 +115,17 @@ def wrap(x): args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) + def set_placeholder(x): + if isinstance(x, FlopTensor): + x._node = subgraph.create_node('placeholder', + 'placeholder', (subgraph._root,), + name=subgraph._graph_namespace.create_name('input', x._tensor)) + x._node.meta['stage'] = 'p' + x._node.meta['out'] = (x._tensor,) + + tree_map(set_placeholder, args) + tree_map(set_placeholder, kwargs) + if isinstance(target, str): # args[0] is the `self` object for this method call self_obj, *args_tail = args @@ -126,15 +144,12 @@ def wrap(x): fwd_flop = flop_count['f'] bwd_flop = flop_count['b'] - fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0 - fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 - bwd_tmp = max(map(activation_size, temp['b'][:-1])) if len(temp['b'][:-1]) else 0 - fwd_out = activation_size(temp['b'][-1]) if len(temp['b']) else 0 + fwd_in, fwd_tmp, bwd_tmp, bwd_out = autograd_graph_analysis(subgraph) def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0) + return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_in, fwd_tmp, bwd_tmp, bwd_out) def profile_function(target: 'Target') -> Callable: From 5f25d6e001ee8722a7b70f35adbe2c4334e9beb8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:23:39 +0800 Subject: [PATCH 05/21] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/profiler/dataflow.py | 18 ++++++++++-------- colossalai/fx/profiler/profiler.py | 26 ++++++++++++++++++++------ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 14d635501bee..dedff555763c 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -68,15 +68,17 @@ def _peak_memory(deps: Dict[Node, int]): for n in graph.nodes: n: Node - if is_placeholder(n): - # a placeholder node who has any backward node users will have to be kept in memory until released - if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): - # but if its users are all inplace methods in forward pass, it should not have activations + if n.meta['save'] and not any(map(is_loss, n.users)): + # A forward tensor who is marked `save` but is not + # an input to `loss` should be saved during forward. + # If the tensor is a placeholder, then it belongs to `fwd_in`. + # Any `fwd_in` should be kept in memory even this function + # is checkpointed. + # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint + # the node, `fwd_tmp` can be freed. + if is_placeholder(n): fwd_in += activation_size(n.meta['out']) - if is_forward(n): - # a forward node who has any backward node users will have to be kept in memory until released - if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): - # but if its users are all inplace methods in forward pass, it should not have activations + if is_forward(n): fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 9968ec0ebee3..10b7b3fa7cb1 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -87,6 +87,7 @@ def get_node(x): flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) node.meta['out'] = normalize_tuple(out) node.meta['stage'] = stage + node.meta['save'] = False def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x @@ -122,16 +123,29 @@ def set_placeholder(x): name=subgraph._graph_namespace.create_name('input', x._tensor)) x._node.meta['stage'] = 'p' x._node.meta['out'] = (x._tensor,) + x._node.meta['save'] = False tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) - if isinstance(target, str): - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - out = getattr(self_obj, target)(*args_tail, **kwargs) - else: - out = target(*args, **kwargs) + fwd_in = 0 + + def pack(x): + if isinstance(x, FlopTensor): + x._node.meta['save'] = True + return x + + def unpack(x): + return x + + # mark saved tensors with save_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + out = getattr(self_obj, target)(*args_tail, **kwargs) + else: + out = target(*args, **kwargs) # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. From 3745c5ff3fca2c0fe4d3feac58bf705304ef2a32 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:26:52 +0800 Subject: [PATCH 06/21] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/passes/meta_info_prop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 2c1a399dcf09..f43be9453817 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -94,13 +94,13 @@ def extract_tensor_meta(obj): n.meta['tensor_meta'] = meta # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', mem_stat[1]) - setattr(n, 'fwd_flop', flop_count[0]) - setattr(n, 'bwd_flop', flop_count[1]) - setattr(n, 'fwd_in', mem_stat[0]) - setattr(n, 'fwd_tmp', mem_stat[1]) - setattr(n, 'bwd_tmp', mem_stat[2]) - setattr(n, 'bwd_out', mem_stat[3]) + setattr(n, 'node_size', mem_stat[0] + mem_stat[1]) + n.meta['fwd_flop'] = flop_count[0] + n.meta['bwd_flop'] = flop_count[1] + n.meta['fwd_in'] = mem_stat[0] + n.meta['fwd_tmp'] = mem_stat[1] + n.meta['bwd_tmp'] = mem_stat[2] + n.meta['bwd_out'] = mem_stat[3] n.meta['type'] = type(result) for param in self.module.parameters(): From 504c607fe05502fbd65d9f734d39bc79bf89abe3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:27:41 +0800 Subject: [PATCH 07/21] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/profiler/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 10b7b3fa7cb1..e068fa8e6c1a 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis -from .memory import INPLACE_ATEN, WEIRD_OPS, activation_size +from .memory import WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping From 5d72a5275d9201e5ff64b4c3a7d246b500042894 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 11:28:28 +0800 Subject: [PATCH 08/21] [fx] a very accurate version on GPT-2. --- colossalai/fx/passes/meta_info_prop.py | 29 +++++++++--------- colossalai/fx/profiler/dataflow.py | 30 +++++++++---------- colossalai/fx/profiler/profiler.py | 41 ++++++++++++-------------- 3 files changed, 48 insertions(+), 52 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index f43be9453817..15aca1d2dcc6 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -40,8 +40,10 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: class MetaInfoProp(torch.fx.Interpreter): """ Execute an FX graph Node-by-Node with meta tensor and - record the shape, FLOPs, MACs and type of the result + record the memory usage, FLOPs, and type of the result into the corresponding node. + All information should be retrieved with + `node.meta.get(key, default=0)`. Usage: BATCH_SIZE = 2 @@ -82,7 +84,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ - result, flop_count, mem_stat = super().run_node(n) + result, meta_info = super().run_node(n) def extract_tensor_meta(obj): if isinstance(obj, torch.Tensor): @@ -90,21 +92,20 @@ def extract_tensor_meta(obj): else: return TensorMetadata(None, None, False, None, 0, False) - meta = tree_map(extract_tensor_meta, result) - n.meta['tensor_meta'] = meta + tensor_meta = tree_map(extract_tensor_meta, result) + n.meta['tensor_meta'] = tensor_meta # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', mem_stat[0] + mem_stat[1]) - n.meta['fwd_flop'] = flop_count[0] - n.meta['bwd_flop'] = flop_count[1] - n.meta['fwd_in'] = mem_stat[0] - n.meta['fwd_tmp'] = mem_stat[1] - n.meta['bwd_tmp'] = mem_stat[2] - n.meta['bwd_out'] = mem_stat[3] + setattr(n, 'node_size', meta_info.get('fwd_tmp', 0) + meta_info.get('fwd_out', 0)) + n.meta = {**n.meta, **meta_info} + for par in n.all_input_nodes: + par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), meta_info.get('fwd_in', 0)) n.meta['type'] = type(result) + # retain the autograd graph for param in self.module.parameters(): param.grad = None + return result # Main Node running APIs @@ -130,7 +131,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict """ result = super().placeholder(target, args, kwargs) # A placeholder node only has activation - return result, (0, 0), (0, activation_size(result), 0, 0) + return result, {} @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -150,7 +151,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0) + return super().get_attr(target, args, kwargs), {} @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -232,7 +233,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - return args[0], (0, 0), (0, 0, 0, 0) + return args[0], {'fwd_in': activation_size(args[0])} def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index dedff555763c..7eb5d868b190 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict from torch.fx import Graph, Node from .memory import INPLACE_ATEN, activation_size @@ -23,7 +23,7 @@ def is_backward(n: Node): return n.meta['stage'] == 'b' -def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: +def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Nodes should have attribute `out` indicating the output of each node. @@ -45,10 +45,7 @@ def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. - fwd_out (int): The output of the entire forward pass. - bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. - bwd_out (int): The output of the entire backward pass. + meta (Dict): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -60,11 +57,12 @@ def _peak_memory(deps: Dict[Node, int]): # deps is used to track all the memory dependencies of the graph. deps = {} - - fwd_in = 0 - fwd_tmp = 0 - bwd_tmp = 0 - bwd_out = 0 + meta = { + 'fwd_in': 0, + 'fwd_tmp': 0, + 'bwd_tmp': 0, + 'bwd_out': 0, + } for n in graph.nodes: n: Node @@ -77,18 +75,18 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - fwd_in += activation_size(n.meta['out']) + meta['fwd_in'] += activation_size(n.meta['out']) if is_forward(n): - fwd_tmp += activation_size(n.meta['out']) + meta['fwd_tmp'] += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - bwd_tmp = max(bwd_tmp, _peak_memory(deps)) + meta['bwd_tmp'] = max(meta['bwd_tmp'], _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - bwd_out += activation_size(n.meta['out']) - return fwd_in, fwd_tmp, bwd_tmp, bwd_out + meta['bwd_out'] += activation_size(n.meta['out']) + return meta diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index e068fa8e6c1a..1acedef69462 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis -from .memory import WEIRD_OPS +from .memory import WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping @@ -67,11 +67,6 @@ def __repr__(self): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): - x = FlopTensor(x.to('meta')) - return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - def get_node(x): return None if not hasattr(x, '_node') else x._node @@ -79,6 +74,12 @@ def get_node(x): kwargs_node = tree_map(get_node, kwargs) node = subgraph.create_node('call_function', func, args_node, kwargs_node) + def unwrap(x): + # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor` + if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): + x = FlopTensor(x.to('meta')) + return x._tensor.to('meta') if isinstance(x, FlopTensor) else x + args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -128,8 +129,6 @@ def set_placeholder(x): tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) - fwd_in = 0 - def pack(x): if isinstance(x, FlopTensor): x._node.meta['save'] = True @@ -138,7 +137,7 @@ def pack(x): def unpack(x): return x - # mark saved tensors with save_tensors_hooks + # mark saved tensors with saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): if isinstance(target, str): # args[0] is the `self` object for this method call @@ -155,15 +154,13 @@ def unpack(x): stage = 'b' loss.backward() - fwd_flop = flop_count['f'] - bwd_flop = flop_count['b'] - - fwd_in, fwd_tmp, bwd_tmp, bwd_out = autograd_graph_analysis(subgraph) + flop_meta = {'fwd_flop': flop_count['f'], 'bwd_flop': flop_count['b']} + mem_meta = autograd_graph_analysis(subgraph) def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_in, fwd_tmp, bwd_tmp, bwd_out) + return tree_map(unwrap, out), {**flop_meta, **mem_meta} def profile_function(target: 'Target') -> Callable: @@ -189,9 +186,9 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) out = func(*args, **kwargs) - return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, *args, **kwargs) - return out, flop_count, mem_stat + return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} + out, meta = _profile(func, *args, **kwargs) + return out, meta f.__name__ = target.__name__ func = target @@ -207,8 +204,8 @@ def profile_method(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - out, flop_count, mem_stat = _profile(target, *args, **kwargs) - return out, flop_count, mem_stat + out, meta = _profile(target, *args, **kwargs) + return out, meta return f @@ -236,9 +233,9 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: args = tree_map(lambda x: x.to('meta'), args) kwargs = tree_map(lambda x: x.to('meta'), kwargs) out = func(*args, **kwargs) - return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, *args, **kwargs) - return out, flop_count, mem_stat + return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} + out, meta = _profile(func, *args, **kwargs) + return out, meta f.__name__ = module.__class__.__name__ func = module.forward From 6bdeb298f65522f4d7d35c240d62bbd5d02252ac Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 15:27:06 +0800 Subject: [PATCH 09/21] [fx] refactor code. --- colossalai/fx/passes/meta_info_prop.py | 36 +++---- colossalai/fx/profiler/__init__.py | 2 +- colossalai/fx/profiler/dataflow.py | 77 +++++++++++---- .../fx/profiler/experimental/profiler.py | 7 +- colossalai/fx/profiler/profiler.py | 97 +++++++++++-------- 5 files changed, 132 insertions(+), 87 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 15aca1d2dcc6..e3ba8955c80a 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,3 +1,5 @@ +from dataclasses import asdict +from colossalai.fx.profiler.profiler import MetaInfo import torch import torch.fx from torch.fx.node import Node, Argument, Target @@ -42,8 +44,6 @@ class MetaInfoProp(torch.fx.Interpreter): Execute an FX graph Node-by-Node with meta tensor and record the memory usage, FLOPs, and type of the result into the corresponding node. - All information should be retrieved with - `node.meta.get(key, default=0)`. Usage: BATCH_SIZE = 2 @@ -94,12 +94,12 @@ def extract_tensor_meta(obj): tensor_meta = tree_map(extract_tensor_meta, result) n.meta['tensor_meta'] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', meta_info.get('fwd_tmp', 0) + meta_info.get('fwd_out', 0)) - n.meta = {**n.meta, **meta_info} + setattr(n, 'node_size', n.meta.get('fwd_tmp', 0) + n.meta.get('fwd_out', 0)) for par in n.all_input_nodes: - par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), meta_info.get('fwd_in', 0)) + par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), n.meta.get('fwd_in', 0)) n.meta['type'] = type(result) # retain the autograd graph @@ -126,12 +126,9 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Returns: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - result = super().placeholder(target, args, kwargs) - # A placeholder node only has activation - return result, {} + return super().placeholder(target, args, kwargs), MetaInfo() @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -148,10 +145,9 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st Return: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().get_attr(target, args, kwargs), {} + return super().get_attr(target, args, kwargs), MetaInfo() @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -167,8 +163,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ assert not isinstance(target, str) return profile_function(target)(*args, **kwargs) @@ -187,8 +182,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ return profile_method(target)(*args, **kwargs) @@ -206,8 +200,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # Retrieve executed args and kwargs values from the environment # Execute the method and return the result @@ -230,10 +223,9 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Return: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return args[0], {'fwd_in': activation_size(args[0])} + return args[0], MetaInfo(fwd_in=activation_size(args[0])) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 1b46bd494a98..bc9418938fc9 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -2,7 +2,7 @@ if META_COMPATIBILITY: from .opcount import flop_mapping from .tensor import MetaTensor - from .profiler import profile_function, profile_method, profile_module, _profile + from .profiler import profile_function, profile_method, profile_module else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 7eb5d868b190..401af22460f0 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,29 +1,73 @@ +from dataclasses import dataclass +from enum import Enum from typing import Dict from torch.fx import Graph, Node -from .memory import INPLACE_ATEN, activation_size +from .memory import activation_size + + +class Stage(Enum): + F = 0 + L = 1 + B = 2 + P = 3 + + +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for the dataflow analysis. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_in (int): See the above illustration. + fwd_tmp (int): See the above illustration. + bwd_tmp (int): See the above illustration. + bwd_out (int): See the above illustration. + """ + fwd_in: int = 0 + fwd_tmp: int = 0 + bwd_tmp: int = 0 + bwd_out: int = 0 def is_forward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'f' + return n.meta['stage'] == Stage.F def is_loss(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'l' + return n.meta['stage'] == Stage.L def is_placeholder(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'p' + return n.meta['stage'] == Stage.P def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'b' + return n.meta['stage'] == Stage.B + + +def is_saved(n: Node): + return n.meta.get('saved', False) -def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: +def autograd_graph_analysis(graph: Graph) -> GraphInfo: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Nodes should have attribute `out` indicating the output of each node. @@ -45,7 +89,7 @@ def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - meta (Dict): Meta information for the dataflow. + graphinfo (GraphInfo): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -57,16 +101,11 @@ def _peak_memory(deps: Dict[Node, int]): # deps is used to track all the memory dependencies of the graph. deps = {} - meta = { - 'fwd_in': 0, - 'fwd_tmp': 0, - 'bwd_tmp': 0, - 'bwd_out': 0, - } + graph_info = GraphInfo() for n in graph.nodes: n: Node - if n.meta['save'] and not any(map(is_loss, n.users)): + if is_saved(n) and not any(map(is_loss, n.users)): # A forward tensor who is marked `save` but is not # an input to `loss` should be saved during forward. # If the tensor is a placeholder, then it belongs to `fwd_in`. @@ -75,18 +114,18 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - meta['fwd_in'] += activation_size(n.meta['out']) + graph_info.fwd_in += activation_size(n.meta['out']) if is_forward(n): - meta['fwd_tmp'] += activation_size(n.meta['out']) + graph_info.fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - meta['bwd_tmp'] = max(meta['bwd_tmp'], _peak_memory(deps)) + graph_info.bwd_tmp = max(graph_info.bwd_tmp, _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - meta['bwd_out'] += activation_size(n.meta['out']) - return meta + graph_info.bwd_out += activation_size(n.meta['out']) + return graph_info diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 46d4add3c5e9..9a746f95fc15 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -3,6 +3,7 @@ from torch.fx.node import Argument, Target from . import meta_profiler_function, meta_profiler_module from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS +from ..profiler import MetaInfo __all__ = ['profile_function', 'profile_module', 'profile_method'] @@ -59,7 +60,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: else: profiler = meta_profiler_function.get(target.__name__) fwd_flop, _ = profiler(*args, **kwargs) - return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = target.__name__ func = target @@ -88,7 +89,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) - return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) return f @@ -118,7 +119,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) - return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = module.__class__.__name__ func = module.forward diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 1acedef69462..47558368fd73 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,16 +1,16 @@ -from lib2to3.pytree import Node -from operator import getitem +from dataclasses import dataclass +from enum import auto from typing import Callable, Any, Dict, Tuple import torch -from torch.fx import Graph +from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .dataflow import autograd_graph_analysis +from .dataflow import autograd_graph_analysis, Stage from .memory import WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping -__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] +__all__ = ['profile_function', 'profile_module', 'profile_method'] def normalize_tuple(x): @@ -23,7 +23,28 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() -def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: +@dataclass +class MetaInfo: + """ + This is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_in (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + fwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + bwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + bwd_out (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_in: int = 0 + fwd_tmp: int = 0 + bwd_tmp: int = 0 + bwd_out: int = 0 + + +def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]: """Profile a Callable function with args and kwargs. Args: @@ -32,22 +53,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: kwargs (Any): Argument Returns: - out (Tuple[Any, ...]): The argument value that was retrieved - flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + out (Tuple[Any, ...]): The argument value that was retrieved. + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # This subgraph traces aten level ops inside one node. subgraph = Graph() - # flop_count serves as a global dictionary to store results. + meta_info = MetaInfo() + + # `flop_count`` serves as a global dictionary to store results. flop_count = { - 'f': 0, - 'l': 0, - 'b': 0, + Stage.F: 0, + Stage.L: 0, + Stage.B: 0, } # `stage` will mark the stage of autograd from outside scope. - stage = 'f' + stage = Stage.F # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. @@ -88,7 +110,6 @@ def unwrap(x): flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) node.meta['out'] = normalize_tuple(out) node.meta['stage'] = stage - node.meta['save'] = False def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x @@ -105,13 +126,13 @@ def set_node(x): if target not in WEIRD_OPS: def wrap(x): - return FlopTensor( - x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + return FlopTensor(x.detach().requires_grad_( + True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x else: def wrap(x): - return FlopTensor( - x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + return FlopTensor(x.detach().requires_grad_( + False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x # Basically, we need to detach the args and kwargs from the outer graph. args = tree_map(wrap, args) @@ -122,16 +143,15 @@ def set_placeholder(x): x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['stage'] = 'p' + x._node.meta['stage'] = Stage.P x._node.meta['out'] = (x._tensor,) - x._node.meta['save'] = False tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) def pack(x): if isinstance(x, FlopTensor): - x._node.meta['save'] = True + x._node.meta['saved'] = True return x def unpack(x): @@ -148,19 +168,22 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. - if is_autogradable(out) and out.requires_grad: - stage = 'l' + if is_autogradable(out) and out.requires_grad and not inplace: + stage = Stage.L loss = out.sum() - stage = 'b' + stage = Stage.B loss.backward() - flop_meta = {'fwd_flop': flop_count['f'], 'bwd_flop': flop_count['b']} - mem_meta = autograd_graph_analysis(subgraph) + graph_info = autograd_graph_analysis(subgraph) + meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] + meta_info.__dict__.update(graph_info.__dict__) + if inplace: + meta_info.fwd_in = 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), {**flop_meta, **mem_meta} + return tree_map(unwrap, out), meta_info def profile_function(target: 'Target') -> Callable: @@ -175,18 +198,13 @@ def profile_function(target: 'Target') -> Callable: Examples: >>> input = torch.rand(100, 100, 100, 100, device='meta') >>> func = torch.nn.functional.relu - >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + >>> output, meta_info = profile_function(func)(input, inplace=False) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_function` is inplace, we should # skip the autograd profiling. - if kwargs.get('inplace', False): - args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) - kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) - out = func(*args, **kwargs) - return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} out, meta = _profile(func, *args, **kwargs) return out, meta @@ -204,7 +222,7 @@ def profile_method(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - out, meta = _profile(target, *args, **kwargs) + out, meta = _profile(target, *args, inplace=False, **kwargs) return out, meta return f @@ -222,19 +240,14 @@ def profile_module(module: torch.nn.Module) -> Callable: Example: >>> input = torch.rand(4, 3, 224, 224, device='meta') >>> mod = torch.nn.Conv2d(3, 128, 3) - >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + >>> output, meta_info = profile_module(mod)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_module` is inplace, we should # skip the autograd profiling. - if getattr(module, 'inplace', False): - args = tree_map(lambda x: x.to('meta'), args) - kwargs = tree_map(lambda x: x.to('meta'), kwargs) - out = func(*args, **kwargs) - return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} - out, meta = _profile(func, *args, **kwargs) + out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs) return out, meta f.__name__ = module.__class__.__name__ From fafb7d0a01b720d265679d90268921482f2d8233 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 15:40:24 +0800 Subject: [PATCH 10/21] [fx] remove redundant inplace=True. --- colossalai/fx/profiler/dataflow.py | 2 ++ colossalai/fx/profiler/profiler.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 401af22460f0..12b8a3f48b3a 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -115,8 +115,10 @@ def _peak_memory(deps: Dict[Node, int]): # the node, `fwd_tmp` can be freed. if is_placeholder(n): graph_info.fwd_in += activation_size(n.meta['out']) + # print(activation_size(n.meta['out'])) if is_forward(n): graph_info.fwd_tmp += activation_size(n.meta['out']) + # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 47558368fd73..33912ec377ac 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -168,7 +168,7 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. - if is_autogradable(out) and out.requires_grad and not inplace: + if is_autogradable(out) and out.requires_grad: stage = Stage.L loss = out.sum() stage = Stage.B @@ -177,8 +177,6 @@ def unpack(x): graph_info = autograd_graph_analysis(subgraph) meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] meta_info.__dict__.update(graph_info.__dict__) - if inplace: - meta_info.fwd_in = 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x From d3c3690d8e1ef409be8fb6ce771343d46bf73ad3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:48:22 +0800 Subject: [PATCH 11/21] [fx] refactor code. --- colossalai/fx/passes/meta_info_prop.py | 14 ++++---- colossalai/fx/profiler/__init__.py | 1 + colossalai/fx/profiler/dataflow.py | 49 ++++++++++++++------------ colossalai/fx/profiler/profiler.py | 49 +++++++------------------- 4 files changed, 48 insertions(+), 65 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e3ba8955c80a..1d2638a02a10 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,12 +1,12 @@ from dataclasses import asdict -from colossalai.fx.profiler.profiler import MetaInfo +from colossalai.fx.profiler import GraphInfo import torch import torch.fx from torch.fx.node import Node, Argument, Target from torch.utils._pytree import tree_map from typing import Any, Tuple, NamedTuple, Dict from torch.fx._compatibility import compatibility -from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size +from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size @compatibility(is_backward_compatible=True) @@ -97,9 +97,9 @@ def extract_tensor_meta(obj): n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', n.meta.get('fwd_tmp', 0) + n.meta.get('fwd_out', 0)) + setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) for par in n.all_input_nodes: - par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), n.meta.get('fwd_in', 0)) + par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0) n.meta['type'] = type(result) # retain the autograd graph @@ -128,7 +128,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().placeholder(target, args, kwargs), MetaInfo() + return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -147,7 +147,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().get_attr(target, args, kwargs), MetaInfo() + return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -225,7 +225,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return args[0], MetaInfo(fwd_in=activation_size(args[0])) + return args[0], GraphInfo(fwd_mem_in=activation_size(args[0])) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index bc9418938fc9..fb19618b2762 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -6,4 +6,5 @@ else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module +from .dataflow import GraphInfo from .memory import parameter_size, activation_size diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 12b8a3f48b3a..da3b56a2adf9 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -6,16 +6,17 @@ class Stage(Enum): - F = 0 - L = 1 - B = 2 - P = 3 + FORWARD = 0 + LOSS = 1 + BACKWARD = 2 + PLACEHOLDER = 3 @dataclass class GraphInfo: """ - GraphInfo is a dataclass for the dataflow analysis. + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. The dataflow analysis is conducted on a single node of the FX graph. ============================================================================ ------------------------------- @@ -32,35 +33,39 @@ class GraphInfo: backward. ------------------------------- ============================================================================ Attributes: - fwd_in (int): See the above illustration. - fwd_tmp (int): See the above illustration. - bwd_tmp (int): See the above illustration. - bwd_out (int): See the above illustration. + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. """ - fwd_in: int = 0 - fwd_tmp: int = 0 - bwd_tmp: int = 0 - bwd_out: int = 0 + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 def is_forward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.F + return n.meta['stage'] == Stage.FORWARD def is_loss(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.L + return n.meta['stage'] == Stage.LOSS def is_placeholder(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.P + return n.meta['stage'] == Stage.PLACEHOLDER def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.B + return n.meta['stage'] == Stage.BACKWARD def is_saved(n: Node): @@ -89,7 +94,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - graphinfo (GraphInfo): Meta information for the dataflow. + graph_info (GraphInfo): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -114,20 +119,20 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - graph_info.fwd_in += activation_size(n.meta['out']) + graph_info.fwd_mem_in += activation_size(n.meta['out']) # print(activation_size(n.meta['out'])) if is_forward(n): - graph_info.fwd_tmp += activation_size(n.meta['out']) + graph_info.fwd_mem_tmp += activation_size(n.meta['out']) # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - graph_info.bwd_tmp = max(graph_info.bwd_tmp, _peak_memory(deps)) + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - graph_info.bwd_out += activation_size(n.meta['out']) + graph_info.bwd_mem_out += activation_size(n.meta['out']) return graph_info diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 33912ec377ac..347c68c3ac5c 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis, Stage -from .memory import WEIRD_OPS, activation_size +from .memory import WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping @@ -23,29 +23,9 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() -@dataclass -class MetaInfo: - """ - This is a dataclass for MetaInfo, which measures - the execution memory cost and FLOPs with `MetaTensor`. - Attributes: - fwd_flop (int): The forward FLOPs of a certain node - bwd_flop (int): The backward FLOPs of a certain node. - fwd_in (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - fwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - bwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - bwd_out (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - """ - fwd_flop: int = 0 - bwd_flop: int = 0 - fwd_in: int = 0 - fwd_tmp: int = 0 - bwd_tmp: int = 0 - bwd_out: int = 0 - - def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]: - """Profile a Callable function with args and kwargs. + """ + Profile a Callable function with args and kwargs. Args: target (Callable): A Callable function @@ -54,22 +34,20 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ... Returns: out (Tuple[Any, ...]): The argument value that was retrieved. - meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # This subgraph traces aten level ops inside one node. subgraph = Graph() - meta_info = MetaInfo() - # `flop_count`` serves as a global dictionary to store results. flop_count = { - Stage.F: 0, - Stage.L: 0, - Stage.B: 0, + Stage.FORWARD: 0, + Stage.LOSS: 0, + Stage.BACKWARD: 0, } # `stage` will mark the stage of autograd from outside scope. - stage = Stage.F + stage = Stage.FORWARD # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. @@ -143,7 +121,7 @@ def set_placeholder(x): x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['stage'] = Stage.P + x._node.meta['stage'] = Stage.PLACEHOLDER x._node.meta['out'] = (x._tensor,) tree_map(set_placeholder, args) @@ -169,19 +147,18 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. if is_autogradable(out) and out.requires_grad: - stage = Stage.L + stage = Stage.LOSS loss = out.sum() - stage = Stage.B + stage = Stage.BACKWARD loss.backward() graph_info = autograd_graph_analysis(subgraph) - meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] - meta_info.__dict__.update(graph_info.__dict__) + graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD] def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), meta_info + return tree_map(unwrap, out), graph_info def profile_function(target: 'Target') -> Callable: From de2be8f6d238c108d1ca4fdb76c19909c4fa4932 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:50:12 +0800 Subject: [PATCH 12/21] [fx] refactor code. --- .../fx/profiler/experimental/profiler.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 9a746f95fc15..cbf95c1ba65c 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -1,12 +1,50 @@ +from dataclasses import dataclass from typing import Callable, Any, Dict, Tuple import torch from torch.fx.node import Argument, Target from . import meta_profiler_function, meta_profiler_module from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS -from ..profiler import MetaInfo __all__ = ['profile_function', 'profile_module', 'profile_method'] + +# this is for compatibility use +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + CALL_FUNCTION_MSG = \ """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n From bc727c2e274fb8e5ec818a7252fd81b3ad317f44 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:54:24 +0800 Subject: [PATCH 13/21] [fx] refactor code. --- colossalai/fx/profiler/experimental/profiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index cbf95c1ba65c..954e8b49bbe9 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -98,7 +98,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: else: profiler = meta_profiler_function.get(target.__name__) fwd_flop, _ = profiler(*args, **kwargs) - return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = target.__name__ func = target @@ -127,7 +127,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) - return out, MetaInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) return f @@ -157,7 +157,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) - return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = module.__class__.__name__ func = module.forward From 13b1d58ada424a01ef3b4e4f06934be8db93afea Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 18:20:44 +0800 Subject: [PATCH 14/21] [fx] dive into backward memory. --- .../fx/passes/algorithms/ckpt_solver_chen.py | 5 +- .../fx/passes/algorithms/ckpt_solver_rotor.py | 46 ++++++++++--------- colossalai/fx/passes/algorithms/utils.py | 14 +++--- colossalai/fx/passes/meta_info_prop.py | 2 +- colossalai/fx/profiler/dataflow.py | 9 +--- .../test_ckpt_torchvision.py | 5 +- .../test_ckpt_solvers/test_linearize.py | 3 +- 7 files changed, 40 insertions(+), 44 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 9ebbd48c75b1..860804f48747 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -73,10 +73,11 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: y = 0 prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): - temp += getattr(n, 'fwd_out') + n: Node + temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp'] y = max(y, temp) if temp > b and n in ckpt_nodes: - x += getattr(n, 'fwd_out') + x += n.meta['fwd_mem_out'] temp = 0 ckpt_intv.append((prev_idx, idx + 1)) prev_idx = idx + 1 diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 0d8ed955301d..1e09439405bb 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -1,11 +1,10 @@ -from typing import List, Set, Tuple, Dict +from typing import List, Tuple import torch from torch.fx import GraphModule, Node from colossalai.fx.graph_module import ColoGraphModule import math from .linearize import linearize from .utils import * -from colossalai.fx.profiler import profile_function, profile_module from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions @@ -25,8 +24,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple: bw = chain.bweight ## backward time, not used cw = chain.cweight + [0] ## size of x (and of y) cbw = chain.cbweight + [0] ## size of xbar - fwd_tmp = chain.fwd_tmp + [0] - bwd_tmp = chain.bwd_tmp + [0] + fwd_mem_tmp = chain.fwd_mem_tmp + [0] + bwd_mem_tmp = chain.bwd_mem_tmp + [0] # Build table opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] @@ -37,7 +36,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple: for m in range(mmax + 1): for i in range(chain.length + 1): #lmax-lmin = 0 - limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) + limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i]) if m >= limit: ## Equation (1) opt[m][i][i] = fw[i] + bw[i] else: @@ -49,9 +48,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple: for i in range(chain.length + 1 - d): # for idx in range(i+1, chain.length + 1): idx = i + d - mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i] + mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i] if idx > i + 1: - mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx))) + mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx))) if m < mmin: opt[m][i][idx] = float("inf") else: @@ -165,7 +164,7 @@ def _fwd_xbar(node: List[Node]) -> int: xbar = 0 for n in node: - xbar += n.fwd_tmp + n.fwd_out + xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out'] return xbar @@ -183,7 +182,7 @@ def _fwd_time(node: List[Node]) -> int: fwd_time = 0 for n in node: # minimum flop count is needed - fwd_time += max(n.fwd_flop, 1) + fwd_time += max(n.meta['fwd_flop'], 1) return fwd_time @@ -201,11 +200,11 @@ def _bwd_time(node: List[Node]) -> int: bwd_time = 0 for n in node: # minimum flop count is needed - bwd_time += max(n.bwd_flop, 1) + bwd_time += max(n.meta['bwd_flop'], 1) return bwd_time -def _get_bwd_tmp(node: List[Node]) -> int: +def _get_bwd_mem_tmp(node: List[Node]) -> int: """Get the backward temp memory of a node Args: @@ -218,29 +217,31 @@ def _get_bwd_tmp(node: List[Node]) -> int: def _get_deps_size(): deps_size = 0 - for key in deps.keys(): - deps_size += key.bwd_out + for k, v in deps.items(): + if v > 0: + deps_size += k.meta['bwd_mem_out'] return deps_size - bwd_tmp = 0 + bwd_mem_tmp = 0 deps = {} # add all the users for last node into deps, # as those nodes' gradient out will be stored in memory - for son in node[-1].users: - deps[son] = 1 + for child in node[-1].users: + deps[child] = 1 for n in reversed(node): - bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp) - deps[n] = len(n._input_nodes) - for son in n.users: - deps[son] -= 1 + for child in n.users: + if child in deps: + deps[child] -= 1 + bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp']) + deps[n] = len(n.all_input_nodes) for key in list(deps.keys()): if deps[key] == 0: del deps[key] - return bwd_tmp + return bwd_mem_tmp def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain: @@ -267,7 +268,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain: bwd_time.append(_bwd_time(node)) x_sizes.append(_compute_output_size(node)) xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node))) - tmp_bwd.append(_get_bwd_tmp(node)) + tmp_bwd.append(_get_bwd_mem_tmp(node)) # if a node with only one inplace op, we need to let x_bar = 0 if len(node) == 1 and _get_inplace(node[0]): @@ -397,6 +398,7 @@ def solver_rotor(gm: ColoGraphModule, mem_unit = mem_limit * (1.0 - eps) // mem_slots MetaInfoProp(gm).run(data) chain: Chain = _construct_chain(node_list, data, mem_unit) + print(chain) opt_table = _compute_table(chain, mem_slots) sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) _annotate_from_sequence(sequence, node_list) diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py index d26f1a2e27b6..78fb0c36393f 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/utils.py @@ -5,24 +5,24 @@ def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): self.bweight = bw self.cweight = cw self.cbweight = cbw - self.fwd_tmp = ftmp - self.bwd_tmp = btmp + self.fwd_mem_tmp = ftmp + self.bwd_mem_tmp = btmp self.length = len(fw) if check and not self.check_lengths(): raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1) - and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length) - and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) + and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length) + and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) def __repr__(self): chain_list = [] for i in range(self.length): - chain_list.append( - (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i])) + chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i], + self.bwd_mem_tmp[i])) i = self.length - chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i])) + chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i])) return chain_list.__repr__() diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 1d2638a02a10..ecaf204f97ee 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -94,7 +94,7 @@ def extract_tensor_meta(obj): tensor_meta = tree_map(extract_tensor_meta, result) n.meta['tensor_meta'] = tensor_meta - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index da3b56a2adf9..4fc2ae35e8fb 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -120,19 +120,14 @@ def _peak_memory(deps: Dict[Node, int]): # the node, `fwd_tmp` can be freed. if is_placeholder(n): graph_info.fwd_mem_in += activation_size(n.meta['out']) - # print(activation_size(n.meta['out'])) if is_forward(n): graph_info.fwd_mem_tmp += activation_size(n.meta['out']) - # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward - deps[n] = len(n.users) - graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 - else: - # basically a backward node without user is a `grad_out` node - graph_info.bwd_mem_out += activation_size(n.meta['out']) + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) + deps[n] = len(n.users) return graph_info diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 4dc1cdc2d9d6..0fb8090bfddd 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -78,7 +78,7 @@ def _run_ckpt_solver(rank): codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) if solver == solver_rotor: - gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + gm = solver(gm, data, mem_limit=1000 * 1024 * 1024, mem_slots=500) else: gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." @@ -89,7 +89,6 @@ def _run_ckpt_solver(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') -@pytest.mark.skip('TODO: refactor ckpt solvers') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) @@ -111,7 +110,7 @@ def _run_ckpt_solver_torch11(rank): MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) if solver == solver_rotor: - gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + gm = solver(gm, data, mem_limit=1000 * 1024 * 1024, mem_slots=500) else: gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index 1f4d4a0bc1a5..357b7ff6f15d 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -15,10 +15,9 @@ with_codegen = False -@pytest.mark.skip(reason='TODO: modify calculations in rotor') @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_linearize(): - MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} + MODEL_DICT = {tm.resnet18: [2000,]} tracer = ColoTracer() for M, budgets in MODEL_DICT.items(): for budget in budgets: From 7df14a254a51ba768cd77becb247c6013b872254 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 18:40:42 +0800 Subject: [PATCH 15/21] [fx] fix variable names in ckpt_solvers and unskip tests. --- colossalai/fx/passes/algorithms/ckpt_solver_rotor.py | 8 ++++---- colossalai/fx/profiler/dataflow.py | 10 ++++++---- .../test_ckpt_solvers/test_ckpt_torchvision.py | 4 ++-- tests/test_fx/test_ckpt_solvers/test_linearize.py | 11 +++++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 1e09439405bb..44ff64779d71 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -231,11 +231,12 @@ def _get_deps_size(): for child in node[-1].users: deps[child] = 1 for n in reversed(node): + bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp']) + + deps[n] = len(n.all_input_nodes) for child in n.users: if child in deps: deps[child] -= 1 - bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp']) - deps[n] = len(n.all_input_nodes) for key in list(deps.keys()): if deps[key] == 0: @@ -379,7 +380,7 @@ def solver_rotor(gm: ColoGraphModule, mem_limit: int, mem_slots: int = 500, cnode: List[str] = None, - eps: float = 0.02) -> ColoGraphModule: + eps: float = 0.0) -> ColoGraphModule: """solver that automatically find activation checkpoint in rotor's manner Args: @@ -398,7 +399,6 @@ def solver_rotor(gm: ColoGraphModule, mem_unit = mem_limit * (1.0 - eps) // mem_slots MetaInfoProp(gm).run(data) chain: Chain = _construct_chain(node_list, data, mem_unit) - print(chain) opt_table = _compute_table(chain, mem_slots) sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) _annotate_from_sequence(sequence, node_list) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 4fc2ae35e8fb..37e828a7c9cd 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -74,7 +74,7 @@ def is_saved(n: Node): def autograd_graph_analysis(graph: Graph) -> GraphInfo: """Analyze the autograd node dependencies and find out the memory usage. - Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. + Basically the input graph should have all nodes marked for keyword `stage`. Nodes should have attribute `out` indicating the output of each node. ============================================================================ Placeholder ----> p o <---- We need to keep track of grad out @@ -91,7 +91,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: l ============================================================================= Args: - graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. + graph (Graph): The autograd graph with nodes marked for keyword `stage`. Returns: graph_info (GraphInfo): Meta information for the dataflow. @@ -125,9 +125,11 @@ def _peak_memory(deps: Dict[Node, int]): elif is_backward(n): if len(n.users): # liveness analysis is only used in backward + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) + deps[n] = len(n.users) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 - graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) - deps[n] = len(n.users) + else: + graph_info.bwd_mem_out = activation_size(n.meta['out']) return graph_info diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 0fb8090bfddd..ea9aec43dec2 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -78,7 +78,7 @@ def _run_ckpt_solver(rank): codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) if solver == solver_rotor: - gm = solver(gm, data, mem_limit=1000 * 1024 * 1024, mem_slots=500) + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) else: gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." @@ -110,7 +110,7 @@ def _run_ckpt_solver_torch11(rank): MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) if solver == solver_rotor: - gm = solver(gm, data, mem_limit=1000 * 1024 * 1024, mem_slots=500) + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) else: gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index 357b7ff6f15d..56e6b2ac4a19 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -17,7 +17,7 @@ @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_linearize(): - MODEL_DICT = {tm.resnet18: [2000,]} + MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} tracer = ColoTracer() for M, budgets in MODEL_DICT.items(): for budget in budgets: @@ -37,7 +37,8 @@ def test_linearize(): if isinstance(op, ForwardNograd): for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -53,7 +54,8 @@ def test_linearize(): ckpt_idx += 1 for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -62,7 +64,8 @@ def test_linearize(): in_ckpt = True for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" del model del gm From ce7b1a9e90ec89a670266c6e4fdd6905a9ed4c5c Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:19:09 +0800 Subject: [PATCH 16/21] [fx] commit my changes. --- .../fx/passes/algorithms/ckpt_solver_rotor.py | 6 +++++- colossalai/fx/passes/meta_info_prop.py | 2 +- colossalai/fx/profiler/dataflow.py | 9 +-------- colossalai/fx/profiler/profiler.py | 14 ++++++++++++-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 44ff64779d71..b7f5e0977e9e 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -2,6 +2,7 @@ import torch from torch.fx import GraphModule, Node from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.profiler import parameter_size import math from .linearize import linearize from .utils import * @@ -380,7 +381,7 @@ def solver_rotor(gm: ColoGraphModule, mem_limit: int, mem_slots: int = 500, cnode: List[str] = None, - eps: float = 0.0) -> ColoGraphModule: + eps: float = 0.00) -> ColoGraphModule: """solver that automatically find activation checkpoint in rotor's manner Args: @@ -396,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule, """ node_list = linearize(gm, cnode) + mem_limit -= parameter_size(gm) mem_unit = mem_limit * (1.0 - eps) // mem_slots MetaInfoProp(gm).run(data) chain: Chain = _construct_chain(node_list, data, mem_unit) @@ -405,4 +407,6 @@ def solver_rotor(gm: ColoGraphModule, # set __sequence__ attribute to GraphModule setattr(gm, "__sequence__", sequence) + print(chain) + print(node_list) return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index ecaf204f97ee..5df34467e5c9 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -99,7 +99,7 @@ def extract_tensor_meta(obj): # TODO: the attribute node_size should be removed in the future setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) for par in n.all_input_nodes: - par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0) + par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0)) n.meta['type'] = type(result) # retain the autograd graph diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 37e828a7c9cd..b6b22f35dfcd 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -124,12 +124,5 @@ def _peak_memory(deps: Dict[Node, int]): graph_info.fwd_mem_tmp += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): - # liveness analysis is only used in backward - graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) - deps[n] = len(n.users) - for input_n in n.all_input_nodes: - if input_n in deps: - deps[input_n] -= 1 - else: - graph_info.bwd_mem_out = activation_size(n.meta['out']) + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, activation_size(n.meta['out'])) return graph_info diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 347c68c3ac5c..c694066623de 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -5,8 +5,8 @@ from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .dataflow import autograd_graph_analysis, Stage -from .memory import WEIRD_OPS +from .dataflow import GraphInfo, autograd_graph_analysis, Stage +from .memory import WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping @@ -180,6 +180,11 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_function` is inplace, we should # skip the autograd profiling. + if kwargs.get('inplace', False): + args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) + kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) + out = func(*args, **kwargs) + return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0) out, meta = _profile(func, *args, **kwargs) return out, meta @@ -222,6 +227,11 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_module` is inplace, we should # skip the autograd profiling. + if getattr(module, 'inplace', False): + args = tree_map(lambda x: x.to('meta'), args) + kwargs = tree_map(lambda x: x.to('meta'), kwargs) + out = func(*args, **kwargs) + return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0) out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs) return out, meta From 5277b72ad5c8c326aeaf3bfc40b8bef12335cb81 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:25:34 +0800 Subject: [PATCH 17/21] [fx] restore skips. --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 1 + tests/test_fx/test_ckpt_solvers/test_linearize.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index ea9aec43dec2..150c759987ab 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -88,6 +88,7 @@ def _run_ckpt_solver(rank): gpc.destroy() +@pytest.mark.skip('TODO: refactor ckpt solvers') @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index 56e6b2ac4a19..4b6f91a4d3c0 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -15,6 +15,7 @@ with_codegen = False +@pytest.mark.skip(reason='TODO: modify calculations in rotor') @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_linearize(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} From 594ec0b6bc6dc056fb0af30c1a6443515a3ad1dd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:26:07 +0800 Subject: [PATCH 18/21] [fx] restore skips. --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 150c759987ab..4dc1cdc2d9d6 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -88,8 +88,8 @@ def _run_ckpt_solver(rank): gpc.destroy() -@pytest.mark.skip('TODO: refactor ckpt solvers') @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skip('TODO: refactor ckpt solvers') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) From 3144f8a73d340dd263375b01aecf898ffca867b0 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:36:05 +0800 Subject: [PATCH 19/21] [fx] chaange stage into phase. --- colossalai/fx/profiler/dataflow.py | 20 ++++++++++---------- colossalai/fx/profiler/profiler.py | 15 +++++++-------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index b5b4163fc1f0..9e729ea7fdab 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -49,23 +49,23 @@ class GraphInfo: def is_forward(n: Node): - assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Phase.FORWARD + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == Phase.FORWARD def is_loss(n: Node): - assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Phase.LOSS + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == Phase.LOSS def is_placeholder(n: Node): - assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Phase.PLACEHOLDER + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == Phase.PLACEHOLDER def is_backward(n: Node): - assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Phase.BACKWARD + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == Phase.BACKWARD def is_saved(n: Node): @@ -74,7 +74,7 @@ def is_saved(n: Node): def autograd_graph_analysis(graph: Graph) -> GraphInfo: """Analyze the autograd node dependencies and find out the memory usage. - Basically the input graph should have all nodes marked for keyword `stage`. + Basically the input graph should have all nodes marked for keyword `phase`. Nodes should have attribute `out` indicating the output of each node. ============================================================================ Placeholder ----> p o <---- We need to keep track of grad out @@ -91,7 +91,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: l ============================================================================= Args: - graph (Graph): The autograd graph with nodes marked for keyword `stage`. + graph (Graph): The autograd graph with nodes marked for keyword `phase`. Returns: graph_info (GraphInfo): Meta information for the dataflow. diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index e0ad56f39728..a152385e8131 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -46,9 +46,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ... Phase.BACKWARD: 0, } - # `stage` will mark the stage of autograd from outside scope. - stage = Phase.FORWARD - # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. # This makes sure we can analyze the dependencies of memory, and @@ -85,9 +82,9 @@ def unwrap(x): # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) - flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) + flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) node.meta['out'] = normalize_tuple(out) - node.meta['stage'] = stage + node.meta['phase'] = phase def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x @@ -121,7 +118,7 @@ def set_placeholder(x): x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['stage'] = Phase.PLACEHOLDER + x._node.meta['phase'] = Phase.PLACEHOLDER x._node.meta['out'] = (x._tensor,) tree_map(set_placeholder, args) @@ -135,6 +132,8 @@ def pack(x): def unpack(x): return x + # `phase` will mark the phase of autograd from outside scope. + phase = Phase.FORWARD # mark saved tensors with saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): if isinstance(target, str): @@ -147,9 +146,9 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. if is_autogradable(out) and out.requires_grad: - stage = Phase.LOSS + phase = Phase.LOSS loss = out.sum() - stage = Phase.BACKWARD + phase = Phase.BACKWARD loss.backward() graph_info = autograd_graph_analysis(subgraph) From 31d6bbfd03e6496828f84f1fe92cfd6797c5ce1f Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:44:36 +0800 Subject: [PATCH 20/21] [fx] chaange stage into phase. --- colossalai/fx/passes/algorithms/ckpt_solver_rotor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index b7f5e0977e9e..9cb48828e1e9 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -381,7 +381,7 @@ def solver_rotor(gm: ColoGraphModule, mem_limit: int, mem_slots: int = 500, cnode: List[str] = None, - eps: float = 0.00) -> ColoGraphModule: + eps: float = 0.02) -> ColoGraphModule: """solver that automatically find activation checkpoint in rotor's manner Args: @@ -407,6 +407,4 @@ def solver_rotor(gm: ColoGraphModule, # set __sequence__ attribute to GraphModule setattr(gm, "__sequence__", sequence) - print(chain) - print(node_list) return gm From aead7e16b0ca5ac6455f8ca6b371d04fcfdf0f4d Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 14 Sep 2022 13:52:45 +0800 Subject: [PATCH 21/21] [fx] chaange stage into phase. --- colossalai/fx/profiler/dataflow.py | 42 ++++++++++-------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 9e729ea7fdab..69319b792c07 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from enum import Enum +from functools import partial from typing import Dict from torch.fx import Graph, Node from .memory import activation_size @@ -48,24 +49,9 @@ class GraphInfo: bwd_mem_out: int = 0 -def is_forward(n: Node): +def is_phase(n: Node, phase: Phase) -> bool: assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == Phase.FORWARD - - -def is_loss(n: Node): - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == Phase.LOSS - - -def is_placeholder(n: Node): - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == Phase.PLACEHOLDER - - -def is_backward(n: Node): - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == Phase.BACKWARD + return n.meta['phase'] == phase def is_saved(n: Node): @@ -98,11 +84,11 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: """ def _peak_memory(deps: Dict[Node, int]): - bwd_tmp = 0 + peak_mem = 0 for k, v in deps.items(): if v > 0: - bwd_tmp += activation_size(k.meta['out']) - return bwd_tmp + peak_mem += activation_size(k.meta['out']) + return peak_mem # deps is used to track all the memory dependencies of the graph. deps = {} @@ -110,19 +96,19 @@ def _peak_memory(deps: Dict[Node, int]): for n in graph.nodes: n: Node - if is_saved(n) and not any(map(is_loss, n.users)): + if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)): # A forward tensor who is marked `save` but is not # an input to `loss` should be saved during forward. - # If the tensor is a placeholder, then it belongs to `fwd_in`. - # Any `fwd_in` should be kept in memory even this function + # If the tensor is a placeholder, then it belongs to `fwd_mem_in`. + # Any `fwd_mem_in` should be kept in memory even this function # is checkpointed. - # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint - # the node, `fwd_tmp` can be freed. - if is_placeholder(n): + # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint + # the node, `fwd_mem_tmp` can be freed. + if is_phase(n, Phase.PLACEHOLDER): graph_info.fwd_mem_in += activation_size(n.meta['out']) - if is_forward(n): + if is_phase(n, Phase.FORWARD): graph_info.fwd_mem_tmp += activation_size(n.meta['out']) - elif is_backward(n): + elif is_phase(n, Phase.BACKWARD): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users)