
[fx] provide a stable but not accurate enough version of profiler. #1547

Merged
FrankLeeeee merged 28 commits into hpcaitech:main from super-dainiu:feature/flop_tensor
Sep 7, 2022
Conversation

@super-dainiu (Contributor) commented Sep 6, 2022

What's new?

With MetaTensor, we can compute the FLOPs of any autograd procedure.

import torch
import torchvision.models as tm

# _profile is the profiler helper introduced in this PR; index 1 of its
# result is the (fwd_flop, bwd_flop) estimate.
tm_models = [
    tm.vgg11,
    tm.resnet18,
    tm.densenet121,
    tm.mobilenet_v3_small,
    tm.resnext50_32x4d,
    tm.wide_resnet50_2,
    tm.regnet_x_16gf,
    tm.mnasnet0_5,
    tm.convnext_tiny,
    tm.efficientnet_b0,
    tm.vit_b_16,
]

for model in tm_models:
    input = torch.rand(4000, 3, 224, 224, device='meta')  # meta tensors carry shapes/dtypes only, no data
    layer = model()
    print(_profile(layer.forward, (input,), {})[1])  # prints (fwd_flop, bwd_flop)

# A single layer works too:
layer = torch.nn.Conv2d(3, 2, 5)
input = torch.rand(4000, 3, 224, 224, device='meta')
print(_profile(layer.forward, (input,), {})[1])
===========================================================================
(30490748928000, 60927113120000)
(7321522176000, 14578522016000)
(11718887424000, 23091915680000)
(260871680000, 7640462784000)
(17287200768000, 62879690656000)
(45957365760000, 91549855648000)
(64395440128000, 238909743008000)
(490859008000, 12065158560000)
(17974926342624, 137366425504000)
(1713884400000, 63079273104000)
(70435619904000, 140725435360000)
(29040000000, 58080000000)
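
Each printed tuple is (fwd_flop, bwd_flop) in raw floating-point operations; dividing by 1e9 gives GFLOPs. A quick conversion helper (mine, not part of the PR):

# Quick helper (not part of the PR): convert a raw FLOP count to GFLOPs.
def gflops(raw: int) -> float:
    return raw / 1e9

fwd, bwd = 7321522176000, 14578522016000  # resnet18 row above (batch size 4000)
print(f'{gflops(fwd):.1f} / {gflops(bwd):.1f} GFLOPs')  # 7321.5 / 14578.5 GFLOPs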

Combined with MetaInfoProp, every node will carry the following attributes, which will facilitate research on activation checkpointing (act_ckpt) with more specific information.

Node:
    flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
    mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out).
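
A minimal sketch of how these per-node attributes could be consumed; only the attribute names come from this PR, while the pass invocation is a commented-out assumption:

import torch
import torchvision.models as tm
from torch.fx import symbolic_trace

gm = symbolic_trace(tm.resnet18())
# Assumed invocation; the actual entry point lives in colossalai.fx:
# MetaInfoProp(gm).run(torch.rand(4, 3, 224, 224, device='meta'))

for node in gm.graph.nodes:
    fwd_flop, bwd_flop = getattr(node, 'flop_count', (0, 0))
    fwd_tmp, fwd_out, bwd_tmp, bwd_out = getattr(node, 'mem_stat', (0, 0, 0, 0))
    # e.g. rank activation-checkpointing candidates by fwd_out vs. bwd_flop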

And the MetaInfoProp estimates are close to the measured values:

| model | estimated_fwd_mem | estimated_param_mem | real_fwd_mem | real_param_mem | fwd_flop | bwd_flop |
|---|---|---|---|---|---|---|
| densenet121 | 550.279 MB | 30.437 MB | 537.939 MB | 30.859 MB | 11.719 GFLOPs | 23.055 GFLOPs |
| densenet161 | 1071.932 MB | 109.409 MB | 1057.037 MB | 111.325 MB | 31.627 GFLOPs | 62.538 GFLOPs |
| densenet169 | 678.152 MB | 53.976 MB | 666.501 MB | 54.724 MB | 13.902 GFLOPs | 27.341 GFLOPs |
| densenet201 | 876.703 MB | 76.347 MB | 866.970 MB | 77.392 MB | 17.765 GFLOPs | 34.930 GFLOPs |
| convnext_tiny | 583.442 MB | 109.059 MB | 539.976 MB | 109.942 MB | 17.975 GFLOPs | 137.358 GFLOPs |
| convnext_small | 934.707 MB | 191.588 MB | 885.770 MB | 192.682 MB | 34.974 GFLOPs | 272.979 GFLOPs |
| convnext_base | 1328.002 MB | 337.950 MB | 1233.178 MB | 338.043 MB | 61.738 GFLOPs | 484.819 GFLOPs |
| convnext_large | 2238.342 MB | 754.423 MB | 2089.938 MB | 755.756 MB | 137.925 GFLOPs | 1089.767 GFLOPs |
| vit_b_16 | 676.548 MB | 330.229 MB | 869.416 MB | 330.229 MB | 70.436 GFLOPs | 90.311 GFLOPs |
| vit_b_32 | 426.163 MB | 336.549 MB | 460.042 MB | 337.311 MB | 17.678 GFLOPs | 23.616 GFLOPs |
| vit_h_14 | 4366.081 MB | 2411.063 MB | 5599.554 MB | 2475.842 MB | 670.212 GFLOPs | 864.708 GFLOPs |
| vit_l_16 | 2065.175 MB | 1160.914 MB | 2594.511 MB | 1162.164 MB | 246.692 GFLOPs | 318.905 GFLOPs |
| vit_l_32 | 1400.561 MB | 1169.340 MB | 1510.237 MB | 1169.434 MB | 61.619 GFLOPs | 81.867 GFLOPs |
| gpt2_medium | 64452.645 MB | 1353.543 MB | 55721.340 MB | 1377.555 MB | 3321.385 GFLOPs | 6634.189 GFLOPs |

TODO

I skipped the tests for the checkpoint solvers because they still need to integrate the new features.

Concerns

This profiler is still not accurate enough: as the table above shows, the estimated memory can deviate noticeably from the measured values.

Tests

All tests passed with PyTorch 1.11 (CI) and PyTorch 1.12 (screenshot of the passing test run omitted).

Comment on lines +184 to +187
@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
    grad_in = torch.empty_like(input)
    return grad_in
@super-dainiu (author):
Only some extra registrations in this file.
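
For context, a small illustration (mine, not from the PR) of what such a meta registration enables: once aten.hardtanh_backward has a meta kernel, the op can run on meta tensors, yielding shape and dtype without any real compute.

import torch

# Runs entirely on the meta device once the meta kernel is registered
# (natively available in newer PyTorch; this PR backfills it where missing).
grad_out = torch.rand(4, 8, device='meta')
inp = torch.rand(4, 8, device='meta')
grad_in = torch.ops.aten.hardtanh_backward(grad_out, inp, -1.0, 1.0)
print(grad_in.shape, grad_in.device)  # torch.Size([4, 8]) meta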

Comment thread: colossalai/__init__.py
Comment on lines +2 to +6
try:
    from . import _meta_registrations
    META_COMPATIBILITY = True
except:
    import torch
    META_COMPATIBILITY = False
@super-dainiu (author):
META_COMPATIBILITY is checked when Colossal-AI initializes.
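
A sketch of the resulting guard pattern for downstream code (illustrative; only the flag name comes from this PR):

import colossalai

if colossalai.META_COMPATIBILITY:
    # aten meta registrations loaded: meta-tensor profiling is available
    pass
else:
    # older PyTorch without the required meta kernels: skip or fall back
    pass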

Comment on lines +106 to +107
for param in self.module.parameters():
    param.grad = None
@super-dainiu (author):
Obviously, we need to clear the gradients of the parameters, because these grads are meta tensors.
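
A tiny illustration of why (mine, assuming autograd runs end-to-end on the meta device, which is what this PR relies on):

import torch

# Gradients accumulated on meta parameters are themselves meta tensors,
# so a real optimizer step cannot consume them; they must be cleared.
p = torch.nn.Parameter(torch.rand(2, 2, device='meta'))
p.sum().backward()
print(p.grad.device)  # meta
p.grad = None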

@@ -0,0 +1,125 @@
from typing import Callable, Any, Dict, Tuple
@super-dainiu (author):
This is the old one, so I did not modify anything except for the output format.

@super-dainiu (author):

This is the old version, kept for PyTorch 1.11.

Comment on lines +8 to +30
if META_COMPATIBILITY:
    aten = torch.ops.aten

    WEIRD_OPS = [
        torch.where,
    ]

    INPLACE_ATEN = [
        aten.add_.Tensor,
        aten.add.Tensor,
        aten.sub_.Tensor,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.mul.Tensor,
        aten.bernoulli_.float,

        # inplace reshaping
        aten.detach.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
@super-dainiu (author):
These are only created if META_COMPATIBILITY is True.
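
A hypothetical sketch (mine, not the PR's code) of how a profiler can exploit such a list: in-place and aliasing ops allocate no fresh output storage, so their output memory cost can be counted as zero.

import torch

# Hypothetical helper: charge output memory only for ops that actually
# allocate a new tensor; INPLACE_ATEN entries alias or mutate their input.
def output_mem_bytes(target, out: torch.Tensor) -> int:
    if target in INPLACE_ATEN:
        return 0
    return out.numel() * out.element_size()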

@FrankLeeeee FrankLeeeee merged commit 4f59693 into hpcaitech:main Sep 7, 2022
@super-dainiu super-dainiu deleted the feature/flop_tensor branch September 7, 2022 05:21