Closed

31 commits
06f8991
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
3cd7d22
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
0849b3b
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 10, 2022
701786c
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
a75e5a2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
c20beb2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
7e87286
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
f027931
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 12, 2022
9b4f460
[fx] merge development into main (#1)
super-dainiu Aug 12, 2022
bea7060
[fx] add rules to linearize computation graphs for searching. (#2)
super-dainiu Aug 16, 2022
86c005d
[fx] merge
super-dainiu Aug 16, 2022
da259cc
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
296b405
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
bf7feea
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
e6c5f70
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
0cbafd8
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 16, 2022
8e14703
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
92e8223
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
3e9531c
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
02c5cae
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
a8616ef
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
083cf7f
[fx] fix inconsistencies.
super-dainiu Aug 17, 2022
9c7441e
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
76f55b7
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
2c8a827
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
b1afd09
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 17, 2022
ff71edc
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
c90d14a
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
ea7250b
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
77406fe
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
0da5d29
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
colossalai/fx/passes/meta_info_prop.py (31 changes: 14 additions & 17 deletions)
@@ -36,20 +36,20 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
     return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
 
 
-def _compute_node_numel(node_metadata: any) -> int:
+def _compute_activation_size(node_metadata: any) -> int:
     """
     Compute numel of a node with ``tensor_meta`` attribute.
     """
     node_numel = 0
 
     if isinstance(node_metadata, TensorMetadata):
-        node_numel += node_metadata.numel
+        node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
     elif isinstance(node_metadata, dict):
         value_list = [v for _, v in node_metadata.items()]
-        node_numel += _compute_node_numel(value_list)
+        node_numel += _compute_activation_size(value_list)
     else:
         for element in node_metadata:
-            node_numel += _compute_node_numel(element)
+            node_numel += _compute_activation_size(element)
 
     return node_numel
@@ -105,6 +105,7 @@ class MetaInfoProp(torch.fx.Interpreter):
     """
 
     def run_node(self, n: Node) -> Any:
+        # TODO: We might run_node(n) with meta data, and count FLOPS for each node
         result = super().run_node(n)
 
         def extract_tensor_meta(obj):
@@ -116,24 +117,20 @@ def extract_tensor_meta(obj):
         meta = _map_aggregate(result, extract_tensor_meta)
         n.meta['tensor_meta'] = meta
 
-        # get byte size for each element
-        size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
-
-        # compute the total size of activation tensors
-        total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
-
-        # compute the total size of model parameters
+        total_activation_size = 0
         total_param_size = 0
         if n.op == 'call_module':
             target_module = n.graph.owning_module.get_submodule(n.target)
+            if not getattr(target_module, 'inplace', False):
+                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
             for param in target_module.parameters():
-                total_param_size += param.numel()
-
-        # compute the total memory cost of activation tensors and model parameters
-        total_activation_size *= size_per_elem_bytes
-        total_param_size *= size_per_elem_bytes
+                total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
+        elif n.op == 'call_function':
+            if 'inplace' not in n.kwargs:
+                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
+        else:
+            total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
 
+        # TODO: node.node_size is not an original attribute
         setattr(n, 'node_size', total_activation_size + total_param_size)
         setattr(n, 'param_size', total_param_size)
         setattr(n, 'activation_size', total_activation_size)
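
After this pass runs, every `torch.fx` node carries `node_size`, `param_size`, and `activation_size` in bytes, and in-place operations report no new activation. Since `MetaInfoProp` subclasses `torch.fx.Interpreter`, the pass is driven with a single `.run(...)` over a traced module. A hypothetical usage sketch (the example model and input are assumptions for illustration, not from the PR):

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(inplace=True))
gm = symbolic_trace(model)
MetaInfoProp(gm).run(torch.randn(4, 128))

for node in gm.graph.nodes:
    # node_size = activation_size + param_size, all in bytes;
    # the in-place ReLU module should report activation_size == 0
    print(node.name, node.node_size, node.param_size, node.activation_size)
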