
[fx] fix MetaInfoProp for incorrect calculations and add detections for inplace op.#1466

Merged
YuliangLiu0306 merged 32 commits into hpcaitech:main from super-dainiu:feature/linear_ckpt
Aug 18, 2022

Conversation

@super-dainiu
Contributor

What's fixed?

In PR #1344, I added more calculations for node statistics. However, those calculations turned out to be incorrect. This PR fixes them, and I hope it can be merged soon to avoid conflicts with other parts of the codebase.
Also, while testing on real models, I found that inplace operations are very common and cannot be ignored when calculating node statistics.
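Inplace ops matter here because they overwrite their input instead of allocating a new activation, so counting their output as fresh memory over-estimates the footprint. Below is a minimal sketch (not the actual MetaInfoProp code) of how such nodes might be flagged in a traced graph, relying on the usual PyTorch conventions: modules with `inplace=True` (e.g. `nn.ReLU`) and functions/methods whose names end with an underscore. The helper name `is_inplace` is hypothetical.

```python
import torch.nn as nn
from torch.fx import symbolic_trace


def is_inplace(node, gm):
    # Hypothetical heuristic, not Colossal-AI's implementation:
    # modules expose an `inplace` attribute; inplace functions/methods
    # follow the trailing-underscore naming convention (relu_, add_, ...).
    if node.op == 'call_module':
        return getattr(gm.get_submodule(node.target), 'inplace', False)
    if node.op == 'call_function':
        return getattr(node.target, '__name__', '').endswith('_')
    if node.op == 'call_method':
        return node.target.endswith('_')
    return False


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))


gm = symbolic_trace(Net())
inplace_nodes = [n.name for n in gm.graph.nodes if is_inplace(n, gm)]
print(inplace_nodes)  # only the inplace ReLU node is flagged
```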

After using the following script to test on resnet18, the result of MetaInfoProp is close to the real memory cost.

from typing import Tuple
import torch
import torch.nn as nn
import torchvision.models as tm
from torch.fx import symbolic_trace
import torch.fx
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


def _forward_mem(gm: torch.fx.GraphModule):
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += getattr(node, 'node_size', 0)
        param_size += getattr(node, 'param_size', 0)
    return node_size / 1024**2, param_size / 1024**2


def data_gen(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return data, label


def test_forward(gm: torch.fx.GraphModule, num_steps: int = 5):
    def get_gpu_mem():
        result = torch.cuda.max_memory_allocated() / 1024**2
        torch.cuda.reset_peak_memory_stats()
        return result

    forward_mem = -get_gpu_mem()
    param_mem = -get_gpu_mem()
    gm.train()
    gm.cuda()
    param_mem += get_gpu_mem()
    criterion = CrossEntropyLoss()
    optimizer = Adam(gm.parameters(), lr=1e-3)
    for n in range(num_steps):
        data, label = data_gen(4, (3, 224, 224))
        output = gm(data)
        optimizer.zero_grad()
        loss = criterion(output, label)
        forward_mem += get_gpu_mem() / num_steps
        loss.backward()
        optimizer.step()
    return forward_mem, param_mem

        
def test_meta_info_prop():
    model = tm.resnet18()
    data = torch.rand(4, 3, 224, 224)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(data)
    meta_forward_mem, meta_param_mem = _forward_mem(gm)
    concrete_forward_mem, concrete_param_mem = test_forward(gm, num_steps=1)
    print(f'Estimated ({meta_forward_mem:.3f} MB, {meta_param_mem:.3f} MB), Real ({concrete_forward_mem:.3f} MB, {concrete_param_mem:.3f} MB)')
    
        
if __name__ == '__main__':
    test_meta_info_prop()

The result is Estimated (137.279 MB, 44.592 MB), Real (141.258 MB, 44.690 MB)
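For reference, the per-node activation sizes that such an estimate is built from can be computed without allocating any real GPU memory by running ops on PyTorch's `meta` device. The sketch below is only illustrative (it is not how MetaInfoProp is implemented internally); it sizes the first conv activation of a ResNet-style stem for the same batch shape as the script above.

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # Size of a tensor's data in bytes: element count times element size.
    return t.numel() * t.element_size()

# Meta tensors carry shape/dtype but no storage, so this runs instantly
# and touches no device memory.
x = torch.empty(4, 3, 224, 224, device='meta')
conv = torch.nn.Conv2d(3, 64, 7, stride=2, padding=3, device='meta')
y = conv(x)  # shape (4, 64, 112, 112), still on the meta device
mb = tensor_bytes(y) / 1024**2
print(f'{mb:.2f} MB')  # 12.25 MB for the fp32 output
```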

super-dainiu and others added 30 commits August 9, 2022 23:23
* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.
YuliangLiu0306 merged commit bbc58d8 into hpcaitech:main Aug 18, 2022
