From 044adf8338267b75e2a3d27c3278fd06177cf068 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 8 Sep 2022 11:51:14 +0800 Subject: [PATCH 01/14] [fx] add some comment and docstrings. --- colossalai/fx/profiler/profiler.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 8f9fb92e0ae4..0547aa3072d7 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -33,19 +33,32 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ + # This subgraph traces aten level ops inside one node. + subgraph = Graph() + # flop_count serves as a global dictionary to store results. flop_count = { 'f': 0, 'l': 0, 'b': 0, } + + # TODO: remove this temp = { 'f': [], 'l': [], 'b': [], } + + # `stage` will mark the stage of autograd from outside scope. stage = 'f' + # FlopTensor not only get the flop statistics of a single node, + # it also build a full autograd graph for this node. + # This makes sure we can analyze the dependencies of memory, and + # decide which forward intermediate results should be kept until + # backward is executed. + # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): def __repr__(self): @@ -78,6 +91,8 @@ def wrap(x): return tree_map(wrap, out) + # `WEIRD_OPS` are tough to handle because they don't accept autograd + # on meta tensor. if target not in WEIRD_OPS: def wrap(x): @@ -89,6 +104,7 @@ def wrap(x): return FlopTensor( x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + # Basically, we need to detach the args and kwargs from the outer graph. args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) @@ -99,6 +115,8 @@ def wrap(x): else: out = target(*args, **kwargs) + # If the output is not a floating point `torch.Tensor` or it does not + # requires grad, then we should not run backward for this node. if is_autogradable(out) and out.requires_grad: stage = 'l' loss = out.sum() @@ -110,7 +128,8 @@ def wrap(x): fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0 fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 - bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0 + bwd_tmp = max(map(activation_size, temp['b'][:-1])) if len(temp['b'][:-1]) else 0 + fwd_out = activation_size(temp['b'][-1]) if len(temp['b']) else 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x @@ -134,11 +153,14 @@ def profile_function(target: 'Target') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # If there is an argument that this `call_function` is inplace, we should + # skip the autograd profiling. if kwargs.get('inplace', False): args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) out = func(*args, **kwargs) - return out, (0, 0), (0, 0, 0, 0) + return out, (out.numel(), out.numel()), (0, 0, 0, 0) out, flop_count, mem_stat = _profile(func, *args, **kwargs) return out, flop_count, mem_stat @@ -178,6 +200,9 @@ def profile_module(module: torch.nn.Module) -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # If there is an argument that this `call_module` is inplace, we should + # skip the autograd profiling. if getattr(module, 'inplace', False): args = tree_map(lambda x: x.to('meta'), args) kwargs = tree_map(lambda x: x.to('meta'), kwargs) From a42ab22345358092932484aa16ba2f25ecdba010 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 8 Sep 2022 13:02:43 +0800 Subject: [PATCH 02/14] [fx] add dataflow analysis for an autograd graph. --- colossalai/fx/profiler/dataflow.py | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 colossalai/fx/profiler/dataflow.py diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py new file mode 100644 index 000000000000..c002acf645a5 --- /dev/null +++ b/colossalai/fx/profiler/dataflow.py @@ -0,0 +1,47 @@ +from typing import Tuple +from torch.fx import Graph, Node +import torch + + +def is_forward(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'f' + + +def is_loss(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'l' + + +def is_backward(n: Node): + assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' + return getattr(n, 'stage') == 'b' + + +def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: + """Analyze the autograd node dependencies and find out the memory usage. + Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward). + Nodes should have attribute `_out_tensors` indicating the output of each node. + + Args: + graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) + + Returns: + fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. + fwd_out (int): The output of the entire forward pass. + bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. + bwd_out (int): + """ + pass + + +def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: + """Apply liveness analysis to a list of nodes in topological order and calculate the peak memory. + + Args: + nodes (Tuple[Node, ...]): A list of nodes in topological order. + + Returns: + memory_peak (int): Peak memory encountered during the execution. + """ + pass From 0d55f26010fdf7e1a46a1192e6ec938f32c5b48c Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 9 Sep 2022 13:28:09 +0800 Subject: [PATCH 03/14] add intepretation for graph analysis. --- colossalai/fx/profiler/dataflow.py | 55 +++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index c002acf645a5..d731baa64034 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,38 +1,69 @@ from typing import Tuple from torch.fx import Graph, Node import torch +from .memory import activation_size def is_forward(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'f' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'f' def is_loss(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'l' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'l' def is_backward(n: Node): - assert hasattr(n, 'stage'), f'Node {n} has no attribute `stage`!' - return getattr(n, 'stage') == 'b' + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'b' def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: """Analyze the autograd node dependencies and find out the memory usage. - Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward). - Nodes should have attribute `_out_tensors` indicating the output of each node. - + Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. + Nodes should have attribute `out_tensors` indicating the output of each node. + ============================================================================ + p o <---- We need to keep track of grad out + |\________ | + ↓ ↘| + f --------> b + |\ \_____ ↑ + | \ ↘ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ ↑ + ↘ ↘| + f ----> b <---- Backward can be freed as soon as it is required no more. + ↘ ↗ + l + ============================================================================= Args: - graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) + graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. fwd_out (int): The output of the entire forward pass. bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. - bwd_out (int): + bwd_out (int): The output of the entire backward pass. """ - pass + # deps is used to track all the memory dependencies of the graph. + deps = {} + + fwd_tmp = 0 + fwd_out = 0 + bwd_tmp = 0 + bwd_out = 0 + + for n in graph.nodes: + n: Node + if is_forward(n): + if any(map(is_backward, n.users)): + fwd_tmp += activation_size(n.meta['out_tensors']) + if any(map(is_loss, n.users)): + fwd_out += activation_size(n.meta['out_tensors']) + elif is_backward(n): + if not len(n.users): + bwd_out += activation_size(n.meta['out_tensors']) def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: From 42c6e8c973fd2a6dca7b90d353fc04e84f8a1819 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 19:37:57 +0800 Subject: [PATCH 04/14] [fx] before doing save_tensor_hooks. --- colossalai/fx/passes/meta_info_prop.py | 4 +- colossalai/fx/profiler/dataflow.py | 84 +++++++++++++++----------- colossalai/fx/profiler/memory.py | 2 - colossalai/fx/profiler/profiler.py | 51 ++++++++++------ 4 files changed, 84 insertions(+), 57 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 1a1e149577c4..2c1a399dcf09 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -97,8 +97,8 @@ def extract_tensor_meta(obj): setattr(n, 'node_size', mem_stat[1]) setattr(n, 'fwd_flop', flop_count[0]) setattr(n, 'bwd_flop', flop_count[1]) - setattr(n, 'fwd_tmp', mem_stat[0]) - setattr(n, 'fwd_out', mem_stat[1]) + setattr(n, 'fwd_in', mem_stat[0]) + setattr(n, 'fwd_tmp', mem_stat[1]) setattr(n, 'bwd_tmp', mem_stat[2]) setattr(n, 'bwd_out', mem_stat[3]) n.meta['type'] = type(result) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index d731baa64034..14d635501bee 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,7 +1,6 @@ -from typing import Tuple +from typing import Dict, Tuple from torch.fx import Graph, Node -import torch -from .memory import activation_size +from .memory import INPLACE_ATEN, activation_size def is_forward(n: Node): @@ -14,6 +13,11 @@ def is_loss(n: Node): return n.meta['stage'] == 'l' +def is_placeholder(n: Node): + assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' + return n.meta['stage'] == 'p' + + def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' return n.meta['stage'] == 'b' @@ -22,20 +26,20 @@ def is_backward(n: Node): def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. - Nodes should have attribute `out_tensors` indicating the output of each node. + Nodes should have attribute `out` indicating the output of each node. ============================================================================ - p o <---- We need to keep track of grad out - |\________ | - ↓ ↘| - f --------> b - |\ \_____ ↑ - | \ ↘ / - f f ----> b <---- Not every forward result needs to be saved for backward - | \____ ↑ - ↘ ↘| - f ----> b <---- Backward can be freed as soon as it is required no more. - ↘ ↗ - l + Placeholder ----> p o <---- We need to keep track of grad out + |\________ | + ↓ ↘| + f --------> b + |\ \_____ ↑ + | \ ↘ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ ↑ + ↘ ↘| + f ----> b <---- Backward can be freed as soon as it is required no more. + ↘ ↗ + l ============================================================================= Args: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. @@ -46,33 +50,43 @@ def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. bwd_out (int): The output of the entire backward pass. """ + + def _peak_memory(deps: Dict[Node, int]): + bwd_tmp = 0 + for k, v in deps.items(): + if v > 0: + bwd_tmp += activation_size(k.meta['out']) + return bwd_tmp + # deps is used to track all the memory dependencies of the graph. deps = {} + fwd_in = 0 fwd_tmp = 0 - fwd_out = 0 bwd_tmp = 0 bwd_out = 0 for n in graph.nodes: n: Node + if is_placeholder(n): + # a placeholder node who has any backward node users will have to be kept in memory until released + if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): + # but if its users are all inplace methods in forward pass, it should not have activations + fwd_in += activation_size(n.meta['out']) if is_forward(n): - if any(map(is_backward, n.users)): - fwd_tmp += activation_size(n.meta['out_tensors']) - if any(map(is_loss, n.users)): - fwd_out += activation_size(n.meta['out_tensors']) + # a forward node who has any backward node users will have to be kept in memory until released + if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): + # but if its users are all inplace methods in forward pass, it should not have activations + fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): - if not len(n.users): - bwd_out += activation_size(n.meta['out_tensors']) - - -def _peak_memory_analysis(nodes: Tuple[Node, ...]) -> int: - """Apply liveness analysis to a list of nodes in topological order and calculate the peak memory. - - Args: - nodes (Tuple[Node, ...]): A list of nodes in topological order. - - Returns: - memory_peak (int): Peak memory encountered during the execution. - """ - pass + if len(n.users): + # liveness analysis is only used in backward + deps[n] = len(n.users) + bwd_tmp = max(bwd_tmp, _peak_memory(deps)) + for input_n in n.all_input_nodes: + if input_n in deps: + deps[input_n] -= 1 + else: + # basically a backward node without user is a `grad_out` node + bwd_out += activation_size(n.meta['out']) + return fwd_in, fwd_tmp, bwd_tmp, bwd_out diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index be51064220e0..c023d0d1eaa3 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -14,12 +14,10 @@ INPLACE_ATEN = [ aten.add_.Tensor, - aten.add.Tensor, aten.sub_.Tensor, aten.div_.Tensor, aten.div_.Scalar, aten.mul_.Tensor, - aten.mul.Tensor, aten.bernoulli_.float, # inplace reshaping diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 0547aa3072d7..9968ec0ebee3 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,9 +1,12 @@ +from lib2to3.pytree import Node +from operator import getitem from typing import Callable, Any, Dict, Tuple import torch from torch.fx import Graph from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS +from .dataflow import autograd_graph_analysis +from .memory import INPLACE_ATEN, WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping @@ -43,13 +46,6 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: 'b': 0, } - # TODO: remove this - temp = { - 'f': [], - 'l': [], - 'b': [], - } - # `stage` will mark the stage of autograd from outside scope. stage = 'f' @@ -61,6 +57,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): + _node: Node + def __repr__(self): if self.grad_fn: return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})" @@ -74,8 +72,12 @@ def unwrap(x): x = FlopTensor(x.to('meta')) return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - def to_meta(x): - return x.to('meta') if isinstance(x, torch.Tensor) else x + def get_node(x): + return None if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = subgraph.create_node('call_function', func, args_node, kwargs_node) args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -83,13 +85,18 @@ def to_meta(x): # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) - if func not in INPLACE_ATEN: - temp[stage].append(tree_map(to_meta, normalize_tuple(out))) + node.meta['out'] = normalize_tuple(out) + node.meta['stage'] = stage def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x - return tree_map(wrap, out) + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + return out # `WEIRD_OPS` are tough to handle because they don't accept autograd # on meta tensor. @@ -108,6 +115,17 @@ def wrap(x): args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) + def set_placeholder(x): + if isinstance(x, FlopTensor): + x._node = subgraph.create_node('placeholder', + 'placeholder', (subgraph._root,), + name=subgraph._graph_namespace.create_name('input', x._tensor)) + x._node.meta['stage'] = 'p' + x._node.meta['out'] = (x._tensor,) + + tree_map(set_placeholder, args) + tree_map(set_placeholder, kwargs) + if isinstance(target, str): # args[0] is the `self` object for this method call self_obj, *args_tail = args @@ -126,15 +144,12 @@ def wrap(x): fwd_flop = flop_count['f'] bwd_flop = flop_count['b'] - fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0 - fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 - bwd_tmp = max(map(activation_size, temp['b'][:-1])) if len(temp['b'][:-1]) else 0 - fwd_out = activation_size(temp['b'][-1]) if len(temp['b']) else 0 + fwd_in, fwd_tmp, bwd_tmp, bwd_out = autograd_graph_analysis(subgraph) def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0) + return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_in, fwd_tmp, bwd_tmp, bwd_out) def profile_function(target: 'Target') -> Callable: From 5f25d6e001ee8722a7b70f35adbe2c4334e9beb8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:23:39 +0800 Subject: [PATCH 05/14] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/profiler/dataflow.py | 18 ++++++++++-------- colossalai/fx/profiler/profiler.py | 26 ++++++++++++++++++++------ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 14d635501bee..dedff555763c 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -68,15 +68,17 @@ def _peak_memory(deps: Dict[Node, int]): for n in graph.nodes: n: Node - if is_placeholder(n): - # a placeholder node who has any backward node users will have to be kept in memory until released - if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): - # but if its users are all inplace methods in forward pass, it should not have activations + if n.meta['save'] and not any(map(is_loss, n.users)): + # A forward tensor who is marked `save` but is not + # an input to `loss` should be saved during forward. + # If the tensor is a placeholder, then it belongs to `fwd_in`. + # Any `fwd_in` should be kept in memory even this function + # is checkpointed. + # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint + # the node, `fwd_tmp` can be freed. + if is_placeholder(n): fwd_in += activation_size(n.meta['out']) - if is_forward(n): - # a forward node who has any backward node users will have to be kept in memory until released - if any(map(is_backward, n.users)) and not any(map(is_loss, n.users)): - # but if its users are all inplace methods in forward pass, it should not have activations + if is_forward(n): fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 9968ec0ebee3..10b7b3fa7cb1 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -87,6 +87,7 @@ def get_node(x): flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) node.meta['out'] = normalize_tuple(out) node.meta['stage'] = stage + node.meta['save'] = False def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x @@ -122,16 +123,29 @@ def set_placeholder(x): name=subgraph._graph_namespace.create_name('input', x._tensor)) x._node.meta['stage'] = 'p' x._node.meta['out'] = (x._tensor,) + x._node.meta['save'] = False tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) - if isinstance(target, str): - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - out = getattr(self_obj, target)(*args_tail, **kwargs) - else: - out = target(*args, **kwargs) + fwd_in = 0 + + def pack(x): + if isinstance(x, FlopTensor): + x._node.meta['save'] = True + return x + + def unpack(x): + return x + + # mark saved tensors with save_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + out = getattr(self_obj, target)(*args_tail, **kwargs) + else: + out = target(*args, **kwargs) # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. From 3745c5ff3fca2c0fe4d3feac58bf705304ef2a32 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:26:52 +0800 Subject: [PATCH 06/14] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/passes/meta_info_prop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 2c1a399dcf09..f43be9453817 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -94,13 +94,13 @@ def extract_tensor_meta(obj): n.meta['tensor_meta'] = meta # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', mem_stat[1]) - setattr(n, 'fwd_flop', flop_count[0]) - setattr(n, 'bwd_flop', flop_count[1]) - setattr(n, 'fwd_in', mem_stat[0]) - setattr(n, 'fwd_tmp', mem_stat[1]) - setattr(n, 'bwd_tmp', mem_stat[2]) - setattr(n, 'bwd_out', mem_stat[3]) + setattr(n, 'node_size', mem_stat[0] + mem_stat[1]) + n.meta['fwd_flop'] = flop_count[0] + n.meta['bwd_flop'] = flop_count[1] + n.meta['fwd_in'] = mem_stat[0] + n.meta['fwd_tmp'] = mem_stat[1] + n.meta['bwd_tmp'] = mem_stat[2] + n.meta['bwd_out'] = mem_stat[3] n.meta['type'] = type(result) for param in self.module.parameters(): From 504c607fe05502fbd65d9f734d39bc79bf89abe3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 12 Sep 2022 22:27:41 +0800 Subject: [PATCH 07/14] [fx] provide an accurate estimation of memory except for GPT-2. --- colossalai/fx/profiler/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 10b7b3fa7cb1..e068fa8e6c1a 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis -from .memory import INPLACE_ATEN, WEIRD_OPS, activation_size +from .memory import WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping From 5d72a5275d9201e5ff64b4c3a7d246b500042894 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 11:28:28 +0800 Subject: [PATCH 08/14] [fx] a very accurate version on GPT-2. --- colossalai/fx/passes/meta_info_prop.py | 29 +++++++++--------- colossalai/fx/profiler/dataflow.py | 30 +++++++++---------- colossalai/fx/profiler/profiler.py | 41 ++++++++++++-------------- 3 files changed, 48 insertions(+), 52 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index f43be9453817..15aca1d2dcc6 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -40,8 +40,10 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: class MetaInfoProp(torch.fx.Interpreter): """ Execute an FX graph Node-by-Node with meta tensor and - record the shape, FLOPs, MACs and type of the result + record the memory usage, FLOPs, and type of the result into the corresponding node. + All information should be retrieved with + `node.meta.get(key, default=0)`. Usage: BATCH_SIZE = 2 @@ -82,7 +84,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ - result, flop_count, mem_stat = super().run_node(n) + result, meta_info = super().run_node(n) def extract_tensor_meta(obj): if isinstance(obj, torch.Tensor): @@ -90,21 +92,20 @@ def extract_tensor_meta(obj): else: return TensorMetadata(None, None, False, None, 0, False) - meta = tree_map(extract_tensor_meta, result) - n.meta['tensor_meta'] = meta + tensor_meta = tree_map(extract_tensor_meta, result) + n.meta['tensor_meta'] = tensor_meta # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', mem_stat[0] + mem_stat[1]) - n.meta['fwd_flop'] = flop_count[0] - n.meta['bwd_flop'] = flop_count[1] - n.meta['fwd_in'] = mem_stat[0] - n.meta['fwd_tmp'] = mem_stat[1] - n.meta['bwd_tmp'] = mem_stat[2] - n.meta['bwd_out'] = mem_stat[3] + setattr(n, 'node_size', meta_info.get('fwd_tmp', 0) + meta_info.get('fwd_out', 0)) + n.meta = {**n.meta, **meta_info} + for par in n.all_input_nodes: + par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), meta_info.get('fwd_in', 0)) n.meta['type'] = type(result) + # retain the autograd graph for param in self.module.parameters(): param.grad = None + return result # Main Node running APIs @@ -130,7 +131,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict """ result = super().placeholder(target, args, kwargs) # A placeholder node only has activation - return result, (0, 0), (0, activation_size(result), 0, 0) + return result, {} @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -150,7 +151,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0) + return super().get_attr(target, args, kwargs), {} @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -232,7 +233,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - return args[0], (0, 0), (0, 0, 0, 0) + return args[0], {'fwd_in': activation_size(args[0])} def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index dedff555763c..7eb5d868b190 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict from torch.fx import Graph, Node from .memory import INPLACE_ATEN, activation_size @@ -23,7 +23,7 @@ def is_backward(n: Node): return n.meta['stage'] == 'b' -def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: +def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Nodes should have attribute `out` indicating the output of each node. @@ -45,10 +45,7 @@ def autograd_graph_analysis(graph: Graph) -> Tuple[int, int, int, int]: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - fwd_tmp (int): Intermediate memory encountered through forward pass. These tensors are not supposed to be freed unless checkpointed. - fwd_out (int): The output of the entire forward pass. - bwd_tmp (int): Intermediate memory (or peak memory) encountered through backward pass. These tensors can be freed as long as it is not required for its users. We will use liveness analysis to detect the peak memory usage. - bwd_out (int): The output of the entire backward pass. + meta (Dict): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -60,11 +57,12 @@ def _peak_memory(deps: Dict[Node, int]): # deps is used to track all the memory dependencies of the graph. deps = {} - - fwd_in = 0 - fwd_tmp = 0 - bwd_tmp = 0 - bwd_out = 0 + meta = { + 'fwd_in': 0, + 'fwd_tmp': 0, + 'bwd_tmp': 0, + 'bwd_out': 0, + } for n in graph.nodes: n: Node @@ -77,18 +75,18 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - fwd_in += activation_size(n.meta['out']) + meta['fwd_in'] += activation_size(n.meta['out']) if is_forward(n): - fwd_tmp += activation_size(n.meta['out']) + meta['fwd_tmp'] += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - bwd_tmp = max(bwd_tmp, _peak_memory(deps)) + meta['bwd_tmp'] = max(meta['bwd_tmp'], _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - bwd_out += activation_size(n.meta['out']) - return fwd_in, fwd_tmp, bwd_tmp, bwd_out + meta['bwd_out'] += activation_size(n.meta['out']) + return meta diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index e068fa8e6c1a..1acedef69462 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis -from .memory import WEIRD_OPS +from .memory import WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping @@ -67,11 +67,6 @@ def __repr__(self): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): - x = FlopTensor(x.to('meta')) - return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - def get_node(x): return None if not hasattr(x, '_node') else x._node @@ -79,6 +74,12 @@ def get_node(x): kwargs_node = tree_map(get_node, kwargs) node = subgraph.create_node('call_function', func, args_node, kwargs_node) + def unwrap(x): + # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor` + if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): + x = FlopTensor(x.to('meta')) + return x._tensor.to('meta') if isinstance(x, FlopTensor) else x + args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -128,8 +129,6 @@ def set_placeholder(x): tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) - fwd_in = 0 - def pack(x): if isinstance(x, FlopTensor): x._node.meta['save'] = True @@ -138,7 +137,7 @@ def pack(x): def unpack(x): return x - # mark saved tensors with save_tensors_hooks + # mark saved tensors with saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): if isinstance(target, str): # args[0] is the `self` object for this method call @@ -155,15 +154,13 @@ def unpack(x): stage = 'b' loss.backward() - fwd_flop = flop_count['f'] - bwd_flop = flop_count['b'] - - fwd_in, fwd_tmp, bwd_tmp, bwd_out = autograd_graph_analysis(subgraph) + flop_meta = {'fwd_flop': flop_count['f'], 'bwd_flop': flop_count['b']} + mem_meta = autograd_graph_analysis(subgraph) def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_in, fwd_tmp, bwd_tmp, bwd_out) + return tree_map(unwrap, out), {**flop_meta, **mem_meta} def profile_function(target: 'Target') -> Callable: @@ -189,9 +186,9 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) out = func(*args, **kwargs) - return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, *args, **kwargs) - return out, flop_count, mem_stat + return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} + out, meta = _profile(func, *args, **kwargs) + return out, meta f.__name__ = target.__name__ func = target @@ -207,8 +204,8 @@ def profile_method(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - out, flop_count, mem_stat = _profile(target, *args, **kwargs) - return out, flop_count, mem_stat + out, meta = _profile(target, *args, **kwargs) + return out, meta return f @@ -236,9 +233,9 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: args = tree_map(lambda x: x.to('meta'), args) kwargs = tree_map(lambda x: x.to('meta'), kwargs) out = func(*args, **kwargs) - return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, *args, **kwargs) - return out, flop_count, mem_stat + return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} + out, meta = _profile(func, *args, **kwargs) + return out, meta f.__name__ = module.__class__.__name__ func = module.forward From 6bdeb298f65522f4d7d35c240d62bbd5d02252ac Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 15:27:06 +0800 Subject: [PATCH 09/14] [fx] refactor code. --- colossalai/fx/passes/meta_info_prop.py | 36 +++---- colossalai/fx/profiler/__init__.py | 2 +- colossalai/fx/profiler/dataflow.py | 77 +++++++++++---- .../fx/profiler/experimental/profiler.py | 7 +- colossalai/fx/profiler/profiler.py | 97 +++++++++++-------- 5 files changed, 132 insertions(+), 87 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 15aca1d2dcc6..e3ba8955c80a 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,3 +1,5 @@ +from dataclasses import asdict +from colossalai.fx.profiler.profiler import MetaInfo import torch import torch.fx from torch.fx.node import Node, Argument, Target @@ -42,8 +44,6 @@ class MetaInfoProp(torch.fx.Interpreter): Execute an FX graph Node-by-Node with meta tensor and record the memory usage, FLOPs, and type of the result into the corresponding node. - All information should be retrieved with - `node.meta.get(key, default=0)`. Usage: BATCH_SIZE = 2 @@ -94,12 +94,12 @@ def extract_tensor_meta(obj): tensor_meta = tree_map(extract_tensor_meta, result) n.meta['tensor_meta'] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', meta_info.get('fwd_tmp', 0) + meta_info.get('fwd_out', 0)) - n.meta = {**n.meta, **meta_info} + setattr(n, 'node_size', n.meta.get('fwd_tmp', 0) + n.meta.get('fwd_out', 0)) for par in n.all_input_nodes: - par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), meta_info.get('fwd_in', 0)) + par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), n.meta.get('fwd_in', 0)) n.meta['type'] = type(result) # retain the autograd graph @@ -126,12 +126,9 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Returns: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - result = super().placeholder(target, args, kwargs) - # A placeholder node only has activation - return result, {} + return super().placeholder(target, args, kwargs), MetaInfo() @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -148,10 +145,9 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st Return: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().get_attr(target, args, kwargs), {} + return super().get_attr(target, args, kwargs), MetaInfo() @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -167,8 +163,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ assert not isinstance(target, str) return profile_function(target)(*args, **kwargs) @@ -187,8 +182,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ return profile_method(target)(*args, **kwargs) @@ -206,8 +200,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # Retrieve executed args and kwargs values from the environment # Execute the method and return the result @@ -230,10 +223,9 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Return: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return args[0], {'fwd_in': activation_size(args[0])} + return args[0], MetaInfo(fwd_in=activation_size(args[0])) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 1b46bd494a98..bc9418938fc9 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -2,7 +2,7 @@ if META_COMPATIBILITY: from .opcount import flop_mapping from .tensor import MetaTensor - from .profiler import profile_function, profile_method, profile_module, _profile + from .profiler import profile_function, profile_method, profile_module else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 7eb5d868b190..401af22460f0 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,29 +1,73 @@ +from dataclasses import dataclass +from enum import Enum from typing import Dict from torch.fx import Graph, Node -from .memory import INPLACE_ATEN, activation_size +from .memory import activation_size + + +class Stage(Enum): + F = 0 + L = 1 + B = 2 + P = 3 + + +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for the dataflow analysis. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_in (int): See the above illustration. + fwd_tmp (int): See the above illustration. + bwd_tmp (int): See the above illustration. + bwd_out (int): See the above illustration. + """ + fwd_in: int = 0 + fwd_tmp: int = 0 + bwd_tmp: int = 0 + bwd_out: int = 0 def is_forward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'f' + return n.meta['stage'] == Stage.F def is_loss(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'l' + return n.meta['stage'] == Stage.L def is_placeholder(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'p' + return n.meta['stage'] == Stage.P def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == 'b' + return n.meta['stage'] == Stage.B + + +def is_saved(n: Node): + return n.meta.get('saved', False) -def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: +def autograd_graph_analysis(graph: Graph) -> GraphInfo: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Nodes should have attribute `out` indicating the output of each node. @@ -45,7 +89,7 @@ def autograd_graph_analysis(graph: Graph) -> Dict[str, int]: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - meta (Dict): Meta information for the dataflow. + graphinfo (GraphInfo): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -57,16 +101,11 @@ def _peak_memory(deps: Dict[Node, int]): # deps is used to track all the memory dependencies of the graph. deps = {} - meta = { - 'fwd_in': 0, - 'fwd_tmp': 0, - 'bwd_tmp': 0, - 'bwd_out': 0, - } + graph_info = GraphInfo() for n in graph.nodes: n: Node - if n.meta['save'] and not any(map(is_loss, n.users)): + if is_saved(n) and not any(map(is_loss, n.users)): # A forward tensor who is marked `save` but is not # an input to `loss` should be saved during forward. # If the tensor is a placeholder, then it belongs to `fwd_in`. @@ -75,18 +114,18 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - meta['fwd_in'] += activation_size(n.meta['out']) + graph_info.fwd_in += activation_size(n.meta['out']) if is_forward(n): - meta['fwd_tmp'] += activation_size(n.meta['out']) + graph_info.fwd_tmp += activation_size(n.meta['out']) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - meta['bwd_tmp'] = max(meta['bwd_tmp'], _peak_memory(deps)) + graph_info.bwd_tmp = max(graph_info.bwd_tmp, _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - meta['bwd_out'] += activation_size(n.meta['out']) - return meta + graph_info.bwd_out += activation_size(n.meta['out']) + return graph_info diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 46d4add3c5e9..9a746f95fc15 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -3,6 +3,7 @@ from torch.fx.node import Argument, Target from . import meta_profiler_function, meta_profiler_module from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS +from ..profiler import MetaInfo __all__ = ['profile_function', 'profile_module', 'profile_method'] @@ -59,7 +60,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: else: profiler = meta_profiler_function.get(target.__name__) fwd_flop, _ = profiler(*args, **kwargs) - return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = target.__name__ func = target @@ -88,7 +89,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) - return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) return f @@ -118,7 +119,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) - return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = module.__class__.__name__ func = module.forward diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 1acedef69462..47558368fd73 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,16 +1,16 @@ -from lib2to3.pytree import Node -from operator import getitem +from dataclasses import dataclass +from enum import auto from typing import Callable, Any, Dict, Tuple import torch -from torch.fx import Graph +from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .dataflow import autograd_graph_analysis +from .dataflow import autograd_graph_analysis, Stage from .memory import WEIRD_OPS, activation_size from .tensor import MetaTensor from .opcount import flop_mapping -__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] +__all__ = ['profile_function', 'profile_module', 'profile_method'] def normalize_tuple(x): @@ -23,7 +23,28 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() -def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: +@dataclass +class MetaInfo: + """ + This is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_in (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + fwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + bwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + bwd_out (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_in: int = 0 + fwd_tmp: int = 0 + bwd_tmp: int = 0 + bwd_out: int = 0 + + +def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]: """Profile a Callable function with args and kwargs. Args: @@ -32,22 +53,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: kwargs (Any): Argument Returns: - out (Tuple[Any, ...]): The argument value that was retrieved - flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + out (Tuple[Any, ...]): The argument value that was retrieved. + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # This subgraph traces aten level ops inside one node. subgraph = Graph() - # flop_count serves as a global dictionary to store results. + meta_info = MetaInfo() + + # `flop_count`` serves as a global dictionary to store results. flop_count = { - 'f': 0, - 'l': 0, - 'b': 0, + Stage.F: 0, + Stage.L: 0, + Stage.B: 0, } # `stage` will mark the stage of autograd from outside scope. - stage = 'f' + stage = Stage.F # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. @@ -88,7 +110,6 @@ def unwrap(x): flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) node.meta['out'] = normalize_tuple(out) node.meta['stage'] = stage - node.meta['save'] = False def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x @@ -105,13 +126,13 @@ def set_node(x): if target not in WEIRD_OPS: def wrap(x): - return FlopTensor( - x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + return FlopTensor(x.detach().requires_grad_( + True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x else: def wrap(x): - return FlopTensor( - x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + return FlopTensor(x.detach().requires_grad_( + False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x # Basically, we need to detach the args and kwargs from the outer graph. args = tree_map(wrap, args) @@ -122,16 +143,15 @@ def set_placeholder(x): x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['stage'] = 'p' + x._node.meta['stage'] = Stage.P x._node.meta['out'] = (x._tensor,) - x._node.meta['save'] = False tree_map(set_placeholder, args) tree_map(set_placeholder, kwargs) def pack(x): if isinstance(x, FlopTensor): - x._node.meta['save'] = True + x._node.meta['saved'] = True return x def unpack(x): @@ -148,19 +168,22 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. - if is_autogradable(out) and out.requires_grad: - stage = 'l' + if is_autogradable(out) and out.requires_grad and not inplace: + stage = Stage.L loss = out.sum() - stage = 'b' + stage = Stage.B loss.backward() - flop_meta = {'fwd_flop': flop_count['f'], 'bwd_flop': flop_count['b']} - mem_meta = autograd_graph_analysis(subgraph) + graph_info = autograd_graph_analysis(subgraph) + meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] + meta_info.__dict__.update(graph_info.__dict__) + if inplace: + meta_info.fwd_in = 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), {**flop_meta, **mem_meta} + return tree_map(unwrap, out), meta_info def profile_function(target: 'Target') -> Callable: @@ -175,18 +198,13 @@ def profile_function(target: 'Target') -> Callable: Examples: >>> input = torch.rand(100, 100, 100, 100, device='meta') >>> func = torch.nn.functional.relu - >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + >>> output, meta_info = profile_function(func)(input, inplace=False) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_function` is inplace, we should # skip the autograd profiling. - if kwargs.get('inplace', False): - args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) - kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) - out = func(*args, **kwargs) - return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} out, meta = _profile(func, *args, **kwargs) return out, meta @@ -204,7 +222,7 @@ def profile_method(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - out, meta = _profile(target, *args, **kwargs) + out, meta = _profile(target, *args, inplace=False, **kwargs) return out, meta return f @@ -222,19 +240,14 @@ def profile_module(module: torch.nn.Module) -> Callable: Example: >>> input = torch.rand(4, 3, 224, 224, device='meta') >>> mod = torch.nn.Conv2d(3, 128, 3) - >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + >>> output, meta_info = profile_module(mod)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # If there is an argument that this `call_module` is inplace, we should # skip the autograd profiling. - if getattr(module, 'inplace', False): - args = tree_map(lambda x: x.to('meta'), args) - kwargs = tree_map(lambda x: x.to('meta'), kwargs) - out = func(*args, **kwargs) - return out, {'fwd_flop': out.numel(), 'bwd_flop': out.numel()} - out, meta = _profile(func, *args, **kwargs) + out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs) return out, meta f.__name__ = module.__class__.__name__ From fafb7d0a01b720d265679d90268921482f2d8233 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 15:40:24 +0800 Subject: [PATCH 10/14] [fx] remove redundant inplace=True. --- colossalai/fx/profiler/dataflow.py | 2 ++ colossalai/fx/profiler/profiler.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 401af22460f0..12b8a3f48b3a 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -115,8 +115,10 @@ def _peak_memory(deps: Dict[Node, int]): # the node, `fwd_tmp` can be freed. if is_placeholder(n): graph_info.fwd_in += activation_size(n.meta['out']) + # print(activation_size(n.meta['out'])) if is_forward(n): graph_info.fwd_tmp += activation_size(n.meta['out']) + # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 47558368fd73..33912ec377ac 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -168,7 +168,7 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. - if is_autogradable(out) and out.requires_grad and not inplace: + if is_autogradable(out) and out.requires_grad: stage = Stage.L loss = out.sum() stage = Stage.B @@ -177,8 +177,6 @@ def unpack(x): graph_info = autograd_graph_analysis(subgraph) meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] meta_info.__dict__.update(graph_info.__dict__) - if inplace: - meta_info.fwd_in = 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x From d3c3690d8e1ef409be8fb6ce771343d46bf73ad3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:48:22 +0800 Subject: [PATCH 11/14] [fx] refactor code. --- colossalai/fx/passes/meta_info_prop.py | 14 ++++---- colossalai/fx/profiler/__init__.py | 1 + colossalai/fx/profiler/dataflow.py | 49 ++++++++++++++------------ colossalai/fx/profiler/profiler.py | 49 +++++++------------------- 4 files changed, 48 insertions(+), 65 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e3ba8955c80a..1d2638a02a10 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,12 +1,12 @@ from dataclasses import asdict -from colossalai.fx.profiler.profiler import MetaInfo +from colossalai.fx.profiler import GraphInfo import torch import torch.fx from torch.fx.node import Node, Argument, Target from torch.utils._pytree import tree_map from typing import Any, Tuple, NamedTuple, Dict from torch.fx._compatibility import compatibility -from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size +from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size @compatibility(is_backward_compatible=True) @@ -97,9 +97,9 @@ def extract_tensor_meta(obj): n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', n.meta.get('fwd_tmp', 0) + n.meta.get('fwd_out', 0)) + setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) for par in n.all_input_nodes: - par.meta['fwd_out'] = max(par.meta.get('fwd_out', 0), n.meta.get('fwd_in', 0)) + par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0) n.meta['type'] = type(result) # retain the autograd graph @@ -128,7 +128,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().placeholder(target, args, kwargs), MetaInfo() + return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -147,7 +147,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return super().get_attr(target, args, kwargs), MetaInfo() + return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -225,7 +225,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return args[0], MetaInfo(fwd_in=activation_size(args[0])) + return args[0], GraphInfo(fwd_mem_in=activation_size(args[0])) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index bc9418938fc9..fb19618b2762 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -6,4 +6,5 @@ else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module +from .dataflow import GraphInfo from .memory import parameter_size, activation_size diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 12b8a3f48b3a..da3b56a2adf9 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -6,16 +6,17 @@ class Stage(Enum): - F = 0 - L = 1 - B = 2 - P = 3 + FORWARD = 0 + LOSS = 1 + BACKWARD = 2 + PLACEHOLDER = 3 @dataclass class GraphInfo: """ - GraphInfo is a dataclass for the dataflow analysis. + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. The dataflow analysis is conducted on a single node of the FX graph. ============================================================================ ------------------------------- @@ -32,35 +33,39 @@ class GraphInfo: backward. ------------------------------- ============================================================================ Attributes: - fwd_in (int): See the above illustration. - fwd_tmp (int): See the above illustration. - bwd_tmp (int): See the above illustration. - bwd_out (int): See the above illustration. + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. """ - fwd_in: int = 0 - fwd_tmp: int = 0 - bwd_tmp: int = 0 - bwd_out: int = 0 + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 def is_forward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.F + return n.meta['stage'] == Stage.FORWARD def is_loss(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.L + return n.meta['stage'] == Stage.LOSS def is_placeholder(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.P + return n.meta['stage'] == Stage.PLACEHOLDER def is_backward(n: Node): assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!' - return n.meta['stage'] == Stage.B + return n.meta['stage'] == Stage.BACKWARD def is_saved(n: Node): @@ -89,7 +94,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. Returns: - graphinfo (GraphInfo): Meta information for the dataflow. + graph_info (GraphInfo): Meta information for the dataflow. """ def _peak_memory(deps: Dict[Node, int]): @@ -114,20 +119,20 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint # the node, `fwd_tmp` can be freed. if is_placeholder(n): - graph_info.fwd_in += activation_size(n.meta['out']) + graph_info.fwd_mem_in += activation_size(n.meta['out']) # print(activation_size(n.meta['out'])) if is_forward(n): - graph_info.fwd_tmp += activation_size(n.meta['out']) + graph_info.fwd_mem_tmp += activation_size(n.meta['out']) # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward deps[n] = len(n.users) - graph_info.bwd_tmp = max(graph_info.bwd_tmp, _peak_memory(deps)) + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 else: # basically a backward node without user is a `grad_out` node - graph_info.bwd_out += activation_size(n.meta['out']) + graph_info.bwd_mem_out += activation_size(n.meta['out']) return graph_info diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 33912ec377ac..347c68c3ac5c 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -6,7 +6,7 @@ from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis, Stage -from .memory import WEIRD_OPS, activation_size +from .memory import WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping @@ -23,29 +23,9 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() -@dataclass -class MetaInfo: - """ - This is a dataclass for MetaInfo, which measures - the execution memory cost and FLOPs with `MetaTensor`. - Attributes: - fwd_flop (int): The forward FLOPs of a certain node - bwd_flop (int): The backward FLOPs of a certain node. - fwd_in (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - fwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - bwd_tmp (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - bwd_out (int): See definitions in https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/fx/profiler/dataflow.py - """ - fwd_flop: int = 0 - bwd_flop: int = 0 - fwd_in: int = 0 - fwd_tmp: int = 0 - bwd_tmp: int = 0 - bwd_out: int = 0 - - def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]: - """Profile a Callable function with args and kwargs. + """ + Profile a Callable function with args and kwargs. Args: target (Callable): A Callable function @@ -54,22 +34,20 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ... Returns: out (Tuple[Any, ...]): The argument value that was retrieved. - meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # This subgraph traces aten level ops inside one node. subgraph = Graph() - meta_info = MetaInfo() - # `flop_count`` serves as a global dictionary to store results. flop_count = { - Stage.F: 0, - Stage.L: 0, - Stage.B: 0, + Stage.FORWARD: 0, + Stage.LOSS: 0, + Stage.BACKWARD: 0, } # `stage` will mark the stage of autograd from outside scope. - stage = Stage.F + stage = Stage.FORWARD # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. @@ -143,7 +121,7 @@ def set_placeholder(x): x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['stage'] = Stage.P + x._node.meta['stage'] = Stage.PLACEHOLDER x._node.meta['out'] = (x._tensor,) tree_map(set_placeholder, args) @@ -169,19 +147,18 @@ def unpack(x): # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. if is_autogradable(out) and out.requires_grad: - stage = Stage.L + stage = Stage.LOSS loss = out.sum() - stage = Stage.B + stage = Stage.BACKWARD loss.backward() graph_info = autograd_graph_analysis(subgraph) - meta_info.fwd_flop, meta_info.bwd_flop = flop_count[Stage.F], flop_count[Stage.B] - meta_info.__dict__.update(graph_info.__dict__) + graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD] def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x - return tree_map(unwrap, out), meta_info + return tree_map(unwrap, out), graph_info def profile_function(target: 'Target') -> Callable: From de2be8f6d238c108d1ca4fdb76c19909c4fa4932 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:50:12 +0800 Subject: [PATCH 12/14] [fx] refactor code. --- .../fx/profiler/experimental/profiler.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 9a746f95fc15..cbf95c1ba65c 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -1,12 +1,50 @@ +from dataclasses import dataclass from typing import Callable, Any, Dict, Tuple import torch from torch.fx.node import Argument, Target from . import meta_profiler_function, meta_profiler_module from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS -from ..profiler import MetaInfo __all__ = ['profile_function', 'profile_module', 'profile_method'] + +# this is for compatibility use +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + CALL_FUNCTION_MSG = \ """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n From bc727c2e274fb8e5ec818a7252fd81b3ad317f44 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 16:54:24 +0800 Subject: [PATCH 13/14] [fx] refactor code. --- colossalai/fx/profiler/experimental/profiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index cbf95c1ba65c..954e8b49bbe9 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -98,7 +98,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: else: profiler = meta_profiler_function.get(target.__name__) fwd_flop, _ = profiler(*args, **kwargs) - return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = target.__name__ func = target @@ -127,7 +127,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) - return out, MetaInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) return f @@ -157,7 +157,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) - return out, MetaInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) f.__name__ = module.__class__.__name__ func = module.forward From db4483a737d37a98eb27b39b296e72a763d01cf4 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 13 Sep 2022 18:21:15 +0800 Subject: [PATCH 14/14] [fx] dive into backward memory. --- colossalai/fx/profiler/dataflow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index da3b56a2adf9..f6efbf312dd6 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -120,10 +120,8 @@ def _peak_memory(deps: Dict[Node, int]): # the node, `fwd_tmp` can be freed. if is_placeholder(n): graph_info.fwd_mem_in += activation_size(n.meta['out']) - # print(activation_size(n.meta['out'])) if is_forward(n): graph_info.fwd_mem_tmp += activation_size(n.meta['out']) - # print(activation_size(n.meta['out'])) elif is_backward(n): if len(n.users): # liveness analysis is only used in backward