
[fx] add profiler for fx nodes. #1480

Merged
FrankLeeeee merged 43 commits into hpcaitech:main from super-dainiu:feature/profiler
Aug 24, 2022

Conversation

@super-dainiu
Contributor

@super-dainiu super-dainiu commented Aug 23, 2022

What's new?

After patching all possible ops, we can now profile the memory cost and FLOPs with a few lines of code. Only the original torch.nn.functional and torch.nn ops are supported for now, but it is not too challenging to profile your own model using MetaInfoProp.
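Before the real example, here is a torch-free sketch of the wrapper pattern behind this API: a profiler wraps a callable so each call returns both the output and a record of costs. Everything below (`Profile`, `profile_function_sketch`, the per-element cost model) is hypothetical illustration, not the actual colossalai.fx.profiler implementation, which dispatches on the op to compute exact counts.

```python
from collections import namedtuple
from functools import wraps

# Hypothetical stand-in for the profiler's result record; the field
# names mirror the attributes printed in the example below.
Profile = namedtuple('Profile', ['param', 'activation', 'flops', 'macs'])

def profile_function_sketch(fn, flops_per_elem=1, bytes_per_elem=4):
    """Wrap fn so each call also returns a Profile (toy cost model)."""
    @wraps(fn)
    def wrapper(data, inplace=False):
        out = fn(data)
        n = len(data)
        # An in-place elementwise op allocates no new activation memory.
        activation = 0 if inplace else n * bytes_per_elem
        return out, Profile(param=0, activation=activation,
                            flops=n * flops_per_elem, macs=0)
    return wrapper

relu = profile_function_sketch(lambda xs: [max(x, 0.0) for x in xs])
out, prof = relu([1.0, -2.0, 3.0])
# out == [1.0, 0.0, 3.0]; prof counts 3 FLOPs and 12 bytes of activation
```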

import torch
from colossalai.fx.profiler import profile_function, profile_module


input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
output, profile = profile_function(func)(input, inplace=False)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

output, profile = profile_function(func)(input, inplace=True)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)
output, profile = profile_module(mod)(input)
print(f"Profiling module {mod},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

===============================================================================
Result:
Profiling function <function relu at 0x7f3b6f8ead30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
Profiling function <function relu at 0x7f3b6f8ead30>,
Param size: 0.000 MB, Activation size: 0.000 MB, 100000000 FLOPs, 0 MACs
Profiling module Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
===============================================================================
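The reported numbers can be sanity-checked by hand. The out-of-place relu allocates one fp32 output element per input element (100^4 × 4 bytes ≈ 381.470 MB) and costs one FLOP per element; for the conv, MACs = N·C_out·H_out·W_out·C_in·K_h·K_w, and the FLOP count matches the common convention of 2 FLOPs per MAC plus one bias add per output element. These formulas are the standard convolution counting conventions, inferred from the numbers above rather than taken from the profiler source:

```python
# ReLU on a (100, 100, 100, 100) fp32 tensor: one output element per input.
relu_elems = 100 ** 4
relu_activation_mb = relu_elems * 4 / 1024 ** 2    # ≈ 381.470 MB
relu_flops = relu_elems                            # 100000000

# Conv2d(3, 128, 3) on (4, 3, 224, 224): stride 1, no padding.
n, c_in, h, w = 4, 3, 224, 224
c_out, k = 128, 3
h_out = w_out = h - k + 1                          # 222
out_elems = n * c_out * h_out * w_out              # 25233408

macs = out_elems * c_in * k * k                    # 681302016
flops = 2 * macs + out_elems                       # 1387837440 (bias add per output element)
activation_mb = out_elems * 4 / 1024 ** 2          # ≈ 96.258 MB
param_mb = (c_out * c_in * k * k + c_out) * 4 / 1024 ** 2   # weights + bias ≈ 0.014 MB
```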

Using MetaInfoProp, we can also trace the model entirely on device='meta' and obtain all the required results, then compare them against concrete CUDA measurements.

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from torch.fx import symbolic_trace
import torch.fx
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


def _forward_mem(gm: torch.fx.GraphModule):
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += getattr(node, '__param__', 0) + getattr(node, '__activation__', 0)
        param_size += getattr(node, '__param__', 0)
    return node_size / 1024**2, param_size / 1024**2


def _forward_flops(gm: torch.fx.GraphModule):
    flops = 0
    macs = 0
    for node in gm.graph.nodes:
        flops += getattr(node, '__flops__', 0)
        macs += getattr(node, '__macs__', 0)
    return flops / 1e9, macs / 1e9


def data_gen(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return data, label


def test_forward(gm: torch.fx.GraphModule, num_steps: int = 5):
    def get_gpu_mem():
        result = torch.cuda.max_memory_allocated() / 1024**2
        torch.cuda.reset_peak_memory_stats()
        return result

    get_gpu_mem()   # reset
    forward_mem = -get_gpu_mem()
    param_mem = -get_gpu_mem()
    gm.train()
    gm.cuda()
    param_mem += get_gpu_mem()
    criterion = CrossEntropyLoss()
    optimizer = Adam(gm.parameters(), lr=1e-3)
    for n in range(num_steps):
        data, label = data_gen(1, (3, 224, 224))
        output = gm(data)
        optimizer.zero_grad()
        loss = criterion(output, label)
        forward_mem += get_gpu_mem() / num_steps
        loss.backward()
        optimizer.step()
    return forward_mem, param_mem

        
def test_meta_info_prop():
    for M in [tm.densenet121, tm.densenet161, tm.densenet169, tm.densenet201]:
        model = M()
        data = torch.rand(1, 3, 224, 224, device='meta')
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(data)
        meta_forward_mem, meta_param_mem = _forward_mem(gm)
        flops, macs = _forward_flops(gm)
        concrete_forward_mem, concrete_param_mem = test_forward(gm, num_steps=1)

        print(f'|{M}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|{flops:.3f}GFLOPs|{macs:.3f}GMACs|')
    
        
if __name__ == '__main__':
    test_meta_info_prop()

===============================================================================
Result:
|<function densenet121 at 0x7f99d58f7b80>|158.786 MB|30.437 MB|156.183 MB|30.859 MB|5.717GFLOPs|2.834GMACs|
|<function densenet161 at 0x7f99d58f7d30>|347.533 MB|109.409 MB|349.309 MB|112.571 MB|15.546GFLOPs|7.728GMACs|
|<function densenet169 at 0x7f99d58f7ee0>|208.338 MB|53.976 MB|209.491 MB|54.724 MB|6.778GFLOPs|3.360GMACs|
|<function densenet201 at 0x7f99d58ff0d0>|274.686 MB|76.347 MB|277.507 MB|77.392 MB|8.659GFLOPs|4.291GMACs|
===============================================================================
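The helper functions above rely on a simple convention: the propagation pass annotates each fx node with `__param__`, `__activation__`, `__flops__`, and `__macs__` attributes, and the summaries just fold over the node list. A torch-free sketch of that accumulation pattern, using a hypothetical `Node` class in place of real fx nodes (the byte/FLOP values are the conv and relu numbers from the first example):

```python
class Node:
    """Hypothetical stand-in for an fx node carrying profiler annotations."""
    def __init__(self, **meta):
        for key, value in meta.items():
            setattr(self, key, value)

def forward_mem(nodes):
    # Sum per-node parameter and activation bytes; report in MB.
    node_size = sum(getattr(n, '__param__', 0) + getattr(n, '__activation__', 0)
                    for n in nodes)
    param_size = sum(getattr(n, '__param__', 0) for n in nodes)
    return node_size / 1024 ** 2, param_size / 1024 ** 2

def forward_flops(nodes):
    # Report GFLOPs; nodes without the annotation contribute zero.
    return sum(getattr(n, '__flops__', 0) for n in nodes) / 1e9

nodes = [
    Node(__param__=14336, __activation__=100933632,
         __flops__=1387837440, __macs__=681302016),   # e.g. a Conv2d node
    Node(__activation__=100933632, __flops__=25233408),  # e.g. a ReLU node
]
total_mb, param_mb = forward_mem(nodes)
gflops = forward_flops(nodes)
```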

super-dainiu and others added 30 commits August 9, 2022 23:23
* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.
Contributor

@Cypher30 Cypher30 left a comment


🙌GREAT🙌

@FrankLeeeee
Contributor

Great work!

@super-dainiu
Contributor Author

(screenshot of local test results)
I passed all tests under tests/test_fx locally on an A100.

@FrankLeeeee FrankLeeeee merged commit 32efe8e into hpcaitech:main Aug 24, 2022