From cdaa42a9a3a1a569a10359cba46938f3c5114db0 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:11:13 +0800
Subject: [PATCH 01/30] add alphafold benchmark

---
 .../test_alphafold/benchmark_alphafold.py     | 131 ++++++++++++++++++
 1 file changed, 131 insertions(+)
 create mode 100644 tests/test_autochunk/test_alphafold/benchmark_alphafold.py

diff --git a/tests/test_autochunk/test_alphafold/benchmark_alphafold.py b/tests/test_autochunk/test_alphafold/benchmark_alphafold.py
new file mode 100644
index 000000000000..58b79883b739
--- /dev/null
+++ b/tests/test_autochunk/test_alphafold/benchmark_alphafold.py
@@ -0,0 +1,131 @@
+import time
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _benchmark_evoformer_stack_gm(
+    data_args: tuple,
+    max_memory: int,
+    get_model: Any,
+    get_data: Any,
+) -> None:
+    # build model and input
+    model = get_model()
+    meta_args, concrete_args = get_data(*data_args)
+    if concrete_args is None:
+        concrete_args = []
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    interp = MetaInfoProp(meta_graph)
+    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
+    interp.propagate(*meta_tensors)
+    codegen = AutoChunkCodeGen(
+        meta_graph,
+        max_memory=max_memory,
+    )
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model,
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # init inputs
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+    model.cuda()
+
+    # bench
+    mem = _benchmark_memory(gm, inputs)
+    speed = _benchmark_speed(gm, inputs)
+    print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+
+
+def _benchmark_evoformer_stack_origin(
+    data_args: tuple,
+    get_model: Any,
+    get_data: Any,
+) -> None:
+    # build model and input
+    model = get_model()
+    meta_args, concrete_args = get_data(*data_args)
+    if concrete_args is None:
+        concrete_args = []
+
+    # init inputs
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+    model.cuda()
+
+    # bench
+    mem = _benchmark_memory(model, inputs)
+    speed = _benchmark_speed(model, inputs)
+    print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+
+
+def _benchmark_memory(model, inputs):
+    with torch.no_grad():
+        torch.cuda.reset_peak_memory_stats()
+        now_mem = torch.cuda.memory_allocated() / 1024**2
+        model(*inputs)
+        new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    return new_max_mem - now_mem
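+# Note on methodology: _benchmark_memory (above) reports the peak CUDA
+# allocation during one forward pass minus the memory already allocated
+# before the pass, i.e. roughly the activation footprint on top of the
+# parameters and inputs. _benchmark_speed (below) runs loop // 2 + 1 untimed
+# warmup passes and brackets the timed loop with torch.cuda.synchronize(),
+# so that asynchronous kernel launches do not shorten the averaged
+# wall-clock time.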
+
+def _benchmark_speed(model, inputs, loop=5):
+    with torch.no_grad():
+        for _ in range(loop // 2 + 1):
+            model(*inputs)
+        torch.cuda.synchronize()
+        time1 = time.time()
+        for _ in range(loop):
+            model(*inputs)
+        torch.cuda.synchronize()
+        time2 = time.time()
+    return (time2 - time1) / loop
+
+
+def benchmark_evoformer_stack():
+    from test_evoformer_stack import get_data, get_model
+    data_args = [128, 256]
+    print("")
+    _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
+    _benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data)
+    _benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data)
+    _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data)
+
+
+if __name__ == "__main__":
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=0,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+    benchmark_evoformer_stack()

From 752e37adf3164996657e0e6012345f947ea41584 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:32:31 +0800
Subject: [PATCH 02/30] rename alphafold test

---
 .../benchmark_alphafold.py                    | 0
 .../test_alphafold_utils.py                   | 0
 .../test_evoformer_block.py                   | 0
 .../test_evoformer_stack.py                   | 0
 .../test_extramsa_block.py                    | 0
 5 files changed, 0 insertions(+), 0 deletions(-)
 rename tests/test_autochunk/{test_alphafold => test_autochunk_alphafold}/benchmark_alphafold.py (100%)
 rename tests/test_autochunk/{test_alphafold => test_autochunk_alphafold}/test_alphafold_utils.py (100%)
 rename tests/test_autochunk/{test_alphafold => test_autochunk_alphafold}/test_evoformer_block.py (100%)
 rename tests/test_autochunk/{test_alphafold => test_autochunk_alphafold}/test_evoformer_stack.py (100%)
 rename tests/test_autochunk/{test_alphafold => test_autochunk_alphafold}/test_extramsa_block.py (100%)

diff --git a/tests/test_autochunk/test_alphafold/benchmark_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_alphafold.py
similarity index 100%
rename from tests/test_autochunk/test_alphafold/benchmark_alphafold.py
rename to tests/test_autochunk/test_autochunk_alphafold/benchmark_alphafold.py
diff --git a/tests/test_autochunk/test_alphafold/test_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_alphafold_utils.py
similarity index 100%
rename from tests/test_autochunk/test_alphafold/test_alphafold_utils.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_alphafold_utils.py
diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_block.py
similarity index 100%
rename from tests/test_autochunk/test_alphafold/test_evoformer_block.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_evoformer_block.py
diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_stack.py
similarity index 100%
rename from tests/test_autochunk/test_alphafold/test_evoformer_stack.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_evoformer_stack.py
diff --git a/tests/test_autochunk/test_alphafold/test_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_extramsa_block.py
similarity index 100%
rename from tests/test_autochunk/test_alphafold/test_extramsa_block.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_extramsa_block.py

From d78a7378bbc438c5ab7e7cf5871f7622a7909e4b Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:35:55 +0800
Subject: [PATCH 03/30] rename tests

---
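Note: patches 02 through 06 give every test module a unique
`test_autochunk_*` basename. The commit messages do not say why, but the
likely reason (an assumption here) is pytest's import behavior: these test
directories ship no __init__.py, so pytest imports each test file under its
bare module name, and two files that share a basename in different folders
collide in sys.modules. Unique basenames also keep the sibling imports used
throughout these tests unambiguous, e.g.:

    from test_autochunk_alphafold_utils import run_test
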
 ...{benchmark_alphafold.py => benchmark_autochunk_alphafold.py} | 2 +-
 ...est_alphafold_utils.py => test_autochunk_alphafold_utils.py} | 0
 ...est_evoformer_block.py => test_autochunk_evoformer_block.py} | 2 +-
 ...est_evoformer_stack.py => test_autochunk_evoformer_stack.py} | 2 +-
 ...{test_extramsa_block.py => test_autochunk_extramsa_block.py} | 2 +-
 5 files changed, 4 insertions(+), 4 deletions(-)
 rename tests/test_autochunk/test_autochunk_alphafold/{benchmark_alphafold.py => benchmark_autochunk_alphafold.py} (98%)
 rename tests/test_autochunk/test_autochunk_alphafold/{test_alphafold_utils.py => test_autochunk_alphafold_utils.py} (100%)
 rename tests/test_autochunk/test_autochunk_alphafold/{test_evoformer_block.py => test_autochunk_evoformer_block.py} (97%)
 rename tests/test_autochunk/test_autochunk_alphafold/{test_evoformer_stack.py => test_autochunk_evoformer_stack.py} (97%)
 rename tests/test_autochunk/test_autochunk_alphafold/{test_extramsa_block.py => test_autochunk_extramsa_block.py} (97%)

diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
similarity index 98%
rename from tests/test_autochunk/test_autochunk_alphafold/benchmark_alphafold.py
rename to tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
index 58b79883b739..d2ff3941271e 100644
--- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_alphafold.py
+++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
@@ -109,7 +109,7 @@ def _benchmark_speed(model, inputs, loop=5):
 
 
 def benchmark_evoformer_stack():
-    from test_evoformer_stack import get_data, get_model
+    from test_autochunk_evoformer_stack import get_data, get_model
     data_args = [128, 256]
     print("")
     _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
similarity index 100%
rename from tests/test_autochunk/test_autochunk_alphafold/test_alphafold_utils.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
similarity index 97%
rename from tests/test_autochunk/test_autochunk_alphafold/test_evoformer_block.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
index 99a54fe18e5d..be727701c091 100644
--- a/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_block.py
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
@@ -12,7 +12,7 @@ except:
     HAS_REPO = False
 
-from test_alphafold_utils import run_test
+from test_autochunk_alphafold_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
similarity index 97%
rename from tests/test_autochunk/test_autochunk_alphafold/test_evoformer_stack.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
index 06aba07990e8..5210c1c8d48e 100644
--- a/tests/test_autochunk/test_autochunk_alphafold/test_evoformer_stack.py
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
@@ -12,7 +12,7 @@ except:
     HAS_REPO = False
 
-from test_alphafold_utils import run_test
+from test_autochunk_alphafold_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
similarity index 97%
rename from tests/test_autochunk/test_autochunk_alphafold/test_extramsa_block.py
rename to tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
index 1b0273a1684f..f8102f351982 100644
--- a/tests/test_autochunk/test_autochunk_alphafold/test_extramsa_block.py
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
@@ -11,7 +11,7 @@
     HAS_REPO = True
 except:
     HAS_REPO = False
-from test_alphafold_utils import run_test
+from test_autochunk_alphafold_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 

From 901126f7a4763bc07a80113c50fa67bbe7ee7cb9 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:37:52 +0800
Subject: [PATCH 04/30] rename diffuser

---
 .../test_autochunk_diffuser_utils.py}         | 0
 .../test_autochunk_unet.py}                   | 2 +-
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename tests/test_autochunk/{test_diffuser/test_diffuser_utils.py => test_autochunk_diffuser/test_autochunk_diffuser_utils.py} (100%)
 rename tests/test_autochunk/{test_diffuser/test_unet.py => test_autochunk_diffuser/test_autochunk_unet.py} (96%)

diff --git a/tests/test_autochunk/test_diffuser/test_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
similarity index 100%
rename from tests/test_autochunk/test_diffuser/test_diffuser_utils.py
rename to tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
diff --git a/tests/test_autochunk/test_diffuser/test_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
similarity index 96%
rename from tests/test_autochunk/test_diffuser/test_unet.py
rename to tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index db154b4bba60..9ebe6f393b20 100644
--- a/tests/test_autochunk/test_diffuser/test_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -13,7 +13,7 @@
     MODELS = []
     HAS_REPO = False
 
-from test_diffuser_utils import run_test
+from test_autochunk_diffuser_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 

From 14b32beb10325592cb38821fd4e7b608645c6f56 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:41:40 +0800
Subject: [PATCH 05/30] rename

---
 ...t_transformer_utils.py => test_autochunk_transformer_utils.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/test_autochunk/test_transformer/{test_transformer_utils.py => test_autochunk_transformer_utils.py} (100%)

diff --git a/tests/test_autochunk/test_transformer/test_transformer_utils.py b/tests/test_autochunk/test_transformer/test_autochunk_transformer_utils.py
similarity index 100%
rename from tests/test_autochunk/test_transformer/test_transformer_utils.py
rename to tests/test_autochunk/test_transformer/test_autochunk_transformer_utils.py

From e93a865db6895b317362fe0c82aa5daf3c2215df Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 1 Feb 2023 15:41:53 +0800
Subject: [PATCH 06/30] rename

---
 tests/test_autochunk/test_transformer/test_autochunk_gpt.py | 2
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py index 256df8bbbae5..6e1076ec792b 100644 --- a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py @@ -13,7 +13,7 @@ MODELS = [] HAS_REPO = False -from test_transformer_utils import run_test +from test_autochunk_transformer_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE From 8d6c17696aa078034d90175819ee3b7fd80103ad Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 1 Feb 2023 15:57:42 +0800 Subject: [PATCH 07/30] update transformer --- .../benchmark_autochunk_transformer.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py new file mode 100644 index 000000000000..65f8b4157585 --- /dev/null +++ b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py @@ -0,0 +1,136 @@ +import time +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _benchmark_autochunk_gpt_gm( + model: Any, + data: tuple, + max_memory: int = None, +) -> None: + model = model.cuda().eval() + + # build model and input + meta_args, concrete_args, sequence = data + if concrete_args is None: + concrete_args = {} + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + ) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda().eval(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # init inputs + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + + # bench + mem = _benchmark_memory(gm, inputs) + speed = _benchmark_speed(gm, inputs) + print("gpt gm, mem: %.2fMB, time: %.4fs" % (mem, speed)) + + +def _benchmark_autochunk_gpt_origin( + model: Any, + data: tuple, +) -> None: + # build model and input + meta_args, concrete_args, sequence = data + if concrete_args is 
None: + concrete_args = {} + + # init inputs + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + + # bench + mem = _benchmark_memory(model, inputs) + speed = _benchmark_speed(model, inputs) + print("gpt origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) + + +def _benchmark_memory(model, inputs): + with torch.no_grad(): + torch.cuda.reset_peak_memory_stats() + now_mem = float(torch.cuda.memory_allocated()) / 1024**2 + model(*inputs) + new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 + return new_max_mem - now_mem + + +def _benchmark_speed(model, inputs, loop=5): + with torch.no_grad(): + for _ in range(loop // 2 + 1): + model(*inputs) + torch.cuda.synchronize() + time1 = time.time() + for _ in range(loop): + model(*inputs) + torch.cuda.synchronize() + time2 = time.time() + return (time2 - time1) / loop + + +def benchmark_autochunk_gpt(): + from test_autochunk_gpt import GPT2Config, GPT2Model, get_data + + batch = 1 + seq = 512 + n_embd = 96 + + model = GPT2Model + config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=4) + model = model(config=config) + shape = [batch, seq] + print("") + _benchmark_autochunk_gpt_origin(model, get_data(shape)) + _benchmark_autochunk_gpt_gm(model, get_data(shape), None) + + +if __name__ == "__main__": + # launch colossalai + colossalai.launch( + config={}, + rank=0, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + benchmark_autochunk_gpt() From a548157ff94bb5510b37f2b8a736f7dc0c0010f0 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 1 Feb 2023 16:16:30 +0800 Subject: [PATCH 08/30] update benchmark --- .../benchmark_autochunk_transformer.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py index 65f8b4157585..8a4d99e5efd5 100644 --- a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py @@ -62,7 +62,7 @@ def _benchmark_autochunk_gpt_gm( # bench mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("gpt gm, mem: %.2fMB, time: %.4fs" % (mem, speed)) + print("gpt autochunk, mem: %.2fMB, time: %.4fs" % (mem, speed)) def _benchmark_autochunk_gpt_origin( @@ -82,7 +82,7 @@ def _benchmark_autochunk_gpt_origin( # bench mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) + print("gpt origin , mem: %.2fMB, time: %.4fs" % (mem, speed)) def _benchmark_memory(model, inputs): @@ -107,18 +107,13 @@ def _benchmark_speed(model, inputs, loop=5): return (time2 - time1) / loop -def benchmark_autochunk_gpt(): +def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import GPT2Config, GPT2Model, get_data - - batch = 1 - seq = 512 - n_embd = 96 - model = GPT2Model - config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=4) + config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head) model = model(config=config) shape = [batch, seq] - print("") + print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head)) _benchmark_autochunk_gpt_origin(model, get_data(shape)) _benchmark_autochunk_gpt_gm(model, get_data(shape), 
None) @@ -133,4 +128,4 @@ def benchmark_autochunk_gpt(): port=free_port(), backend="nccl", ) - benchmark_autochunk_gpt() + benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12) From ef4bf3d4bb97e821644ac8fd8cacd002fc02c5de Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 1 Feb 2023 16:18:49 +0800 Subject: [PATCH 09/30] update benchmark --- .../test_transformer/benchmark_autochunk_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py index 8a4d99e5efd5..836a94e9c882 100644 --- a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py @@ -82,7 +82,7 @@ def _benchmark_autochunk_gpt_origin( # bench mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin , mem: %.2fMB, time: %.4fs" % (mem, speed)) + print("gpt origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) def _benchmark_memory(model, inputs): From 4255f1cf83cd21d0c6e75c4018bc38284ea9fa28 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 1 Feb 2023 16:21:05 +0800 Subject: [PATCH 10/30] update bench memory --- .../test_autochunk_alphafold/benchmark_autochunk_alphafold.py | 2 +- .../test_transformer/benchmark_autochunk_transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index d2ff3941271e..2f56f139abaf 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -90,7 +90,7 @@ def _benchmark_memory(model, inputs): with torch.no_grad(): torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 - model(*inputs) + model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 return new_max_mem - now_mem diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py index 836a94e9c882..3f9a3542d8f5 100644 --- a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py @@ -89,7 +89,7 @@ def _benchmark_memory(model, inputs): with torch.no_grad(): torch.cuda.reset_peak_memory_stats() now_mem = float(torch.cuda.memory_allocated()) / 1024**2 - model(*inputs) + model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 return new_max_mem - now_mem From 75ed562202706975aa3474124085cff756a9c058 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 11:37:43 +0800 Subject: [PATCH 11/30] update transformer benchmark --- .../benchmark_autochunk_transformer.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py index 3f9a3542d8f5..43cefcb74988 100644 --- a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py +++ 
b/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py @@ -8,6 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import parameter_size from colossalai.utils import free_port if AUTOCHUNK_AVAILABLE: @@ -60,9 +61,11 @@ def _benchmark_autochunk_gpt_gm( model.cuda().eval() # bench - mem = _benchmark_memory(gm, inputs) + para_mem = float(parameter_size(model)) / 1024**2 * 6 + act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("gpt autochunk, mem: %.2fMB, time: %.4fs" % (mem, speed)) + print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) def _benchmark_autochunk_gpt_origin( @@ -80,9 +83,12 @@ def _benchmark_autochunk_gpt_origin( model.cuda().eval() # bench - mem = _benchmark_memory(model, inputs) + para_mem = float(parameter_size(model)) / 1024**2 * 6 + act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) + print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) + return act_mem def _benchmark_memory(model, inputs): @@ -111,10 +117,19 @@ def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import GPT2Config, GPT2Model, get_data model = GPT2Model config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head) + config.max_position_embeddings = seq model = model(config=config) shape = [batch, seq] print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head)) - _benchmark_autochunk_gpt_origin(model, get_data(shape)) + max_mem = _benchmark_autochunk_gpt_origin(model, get_data(shape)) + for ratio in [0.5, 0.4, 0.3, 0.2]: + try: + _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) + except RuntimeError as e: + if e.args[0] == 'Search failed. 
Try a larger memory threshold.': + break + except Exception as e: + raise e _benchmark_autochunk_gpt_gm(model, get_data(shape), None) @@ -128,4 +143,8 @@ def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): port=free_port(), backend="nccl", ) - benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=1024, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=2048, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=6144, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=8192, n_embd=768, n_head=12) From a5940dc3ed4443c41c26d5b73abab6642e27d24e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 11:58:28 +0800 Subject: [PATCH 12/30] rename --- .../benchmark_autochunk_transformer.py | 0 .../test_autochunk_gpt.py | 0 .../test_autochunk_transformer_utils.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/test_autochunk/{test_transformer => test_autochunk_transformer}/benchmark_autochunk_transformer.py (100%) rename tests/test_autochunk/{test_transformer => test_autochunk_transformer}/test_autochunk_gpt.py (100%) rename tests/test_autochunk/{test_transformer => test_autochunk_transformer}/test_autochunk_transformer_utils.py (100%) diff --git a/tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py similarity index 100% rename from tests/test_autochunk/test_transformer/benchmark_autochunk_transformer.py rename to tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py diff --git a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py similarity index 100% rename from tests/test_autochunk/test_transformer/test_autochunk_gpt.py rename to tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py diff --git a/tests/test_autochunk/test_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py similarity index 100% rename from tests/test_autochunk/test_transformer/test_autochunk_transformer_utils.py rename to tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py From d5f39a6fe3ebfef55176f91ed15f822a9b75c6eb Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 15:04:13 +0800 Subject: [PATCH 13/30] support diffuser --- .../test_autochunk_diffuser/test_autochunk_diffuser_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index 0f3d22dc51e2..ee228f842aed 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -35,6 +35,7 @@ def assert_codegen_run( meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, concrete_args={k: v for k, v in concrete_args}, ) + model = model.cuda().eval() interp = MetaInfoProp(meta_graph) meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] interp.propagate(*meta_tensors) @@ -65,6 +66,7 @@ def assert_codegen_run( # assert result inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if 
isinstance(i, torch.Tensor) else i for i in inputs] model.cuda().eval() gm.eval() with torch.no_grad(): From b6992e074e099e1de3e047485e96d89583162009 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 15:05:48 +0800 Subject: [PATCH 14/30] support unet metainfo prop --- colossalai/fx/_meta_registrations.py | 31 +++++++++++++--------------- colossalai/fx/profiler/opcount.py | 21 +++++++++++++++++++ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py index 8c0201c71e08..153214447223 100644 --- a/colossalai/fx/_meta_registrations.py +++ b/colossalai/fx/_meta_registrations.py @@ -164,18 +164,9 @@ def pick_memory_format(): @register_meta(aten._convolution.default) -def meta_conv_1( - input_tensor: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - stride: List[int], - padding: List[int], - dilation: List[int], - is_transposed: bool, - output_padding: List[int], - groups: int, - *extra_args -): +def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], + padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, + *extra_args): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @@ -233,11 +224,8 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ( - [mini_batch, seq_length, out_size * num_directions] - if batch_first - else [seq_length, mini_batch, out_size * num_directions] - ) + out_shape = ([mini_batch, seq_length, out_size * + num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me return dX, dgamma, dbeta +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp +@register_meta(aten.native_group_norm_backward.default) +def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(gamma) + dbeta = torch.empty_like(gamma) + return dX, dgamma, dbeta + + # ================================== Misc ========================================== # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.roll.default) diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 6bd612ad2fd1..d780ef6d49c9 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: return flops +def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the baddbmm(batch add and batch matmul) operation. 
+ """ + # Inputs = [input, batch1, batch2] + # out = input + batch1 x batch2 + assert len(inputs) == 3, len(inputs) + n, c, t = inputs[1].shape + d = inputs[2].shape[-1] + flops = n * c * t * d + return flops + + def conv_flop_count( x_shape: List[int], w_shape: List[int], @@ -196,6 +209,7 @@ def zero_flop_jit(*args): aten.matmul.default: matmul_flop_jit, aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, + aten.baddbmm.default: baddbmm_flop_jit, # convolution aten.convolution.default: conv_flop_jit, @@ -209,6 +223,8 @@ def zero_flop_jit(*args): aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + aten.native_group_norm.default: norm_flop_counter(2, 0), + aten.native_group_norm_backward.default: norm_flop_counter(2, 0), # pooling aten.avg_pool1d.default: elementwise_flop_counter(1, 0), @@ -230,6 +246,8 @@ def zero_flop_jit(*args): aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), aten.embedding.default: elementwise_flop_counter(1, 0), + aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1), + aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1), } elementwise_flop_aten = [ @@ -251,6 +269,9 @@ def zero_flop_jit(*args): aten.mean.dim, aten.sub.Tensor, aten.sub_.Tensor, + aten.exp.default, + aten.sin.default, + aten.cos.default, # activation op aten.hardswish.default, From 1fde8224b73ac6f20309927ea315173ccdbe75e8 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 17:25:37 +0800 Subject: [PATCH 15/30] fix bug and simplify code --- colossalai/autochunk/trace_indice.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index b591fa764423..0c61b6a1b407 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -469,17 +469,6 @@ def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None: dim_idx = list(range(len(get_node_shape(node))))[dim_idx] self._add_dim(node_idx, dim_idx) - def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None: - """ - Assign indice for oneslike op. - 1. assign new indice for all dim - - Args: - node (node) - node_idx (int) - """ - self._assign_all_indice(node, node_idx) - def _assign_cat_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for cat op. 
@@ -766,7 +755,7 @@ def trace_indice(self) -> None: elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]): self._assgin_no_change_indice(node, idx) elif "new_ones" == node_name: - self._assign_ones_like_indice(node, idx) + self._assign_all_indice(node, idx) elif any(i == node_name for i in ["size"]): continue else: @@ -793,8 +782,6 @@ def trace_indice(self) -> None: "tanh", ]): self._assign_elementwise_indice(node, idx) - elif "ones_like" == node_name: - self._assign_ones_like_indice(node, idx) elif "einsum" == node_name: self._assign_einsum_indice(node, idx) elif "sum" == node_name: @@ -805,10 +792,8 @@ def trace_indice(self) -> None: self._assign_getitem_indice(node, idx) elif "addmm" == node_name: self._assign_addmm_indice(node, idx) - elif "arange" == node_name: - self._assign_arange_indice(node, idx) - elif "tensor" == node_name: - self._assign_arange_indice(node, idx) + elif any(i == node_name for i in ["arange", "one", "ones_like", "tensor"]): + self._assign_all_indice(node, idx) elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): continue else: From 05ca225a83b2cf06dc705c49ae24593929883226 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 2 Feb 2023 17:39:31 +0800 Subject: [PATCH 16/30] update linear and support some op --- colossalai/autochunk/trace_indice.py | 30 +++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 0c61b6a1b407..68a19e9162e0 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -308,14 +308,14 @@ def _assign_linear_indice(self, node: Node, node_idx: int) -> None: node (node) node_idx (int) """ - if len(node.args) == 2: - _, weight = node.args - else: - _, weight, _ = node.args - self._assign_indice_as_input(node, node_idx) - self._inherit_indice(weight, 1, node, -1) + if len(node.args) >= 2: + weight = node.args[1] + self._inherit_indice(weight, 1, node, -1) + else: + self._del_dim(node_idx, -1) + self._add_dim(node_idx, -1) self._mark_computation(node, node_idx, [-1]) def _assign_addmm_indice(self, node: Node, node_idx: int) -> None: @@ -752,7 +752,7 @@ def trace_indice(self) -> None: self._assign_unsqueeze_indice(node, idx) elif "split" == node_name: self._assign_split_indice(node, idx) - elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]): + elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]): self._assgin_no_change_indice(node, idx) elif "new_ones" == node_name: self._assign_all_indice(node, idx) @@ -770,16 +770,8 @@ def trace_indice(self) -> None: elif "softmax" == node_name: self._assign_softmax_indice(node, idx) elif any(n == node_name for n in [ - "mul", - "add", - "sigmoid", - "relu", - "sub", - "truediv", - "pow", - "dropout", - "where", - "tanh", + "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", + "sin", "cos" ]): self._assign_elementwise_indice(node, idx) elif "einsum" == node_name: @@ -792,7 +784,7 @@ def trace_indice(self) -> None: self._assign_getitem_indice(node, idx) elif "addmm" == node_name: self._assign_addmm_indice(node, idx) - elif any(i == node_name for i in ["arange", "one", "ones_like", "tensor"]): + elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor"]): self._assign_all_indice(node, idx) elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): continue @@ -804,6 +796,8 @@ 
def trace_indice(self) -> None: self._assign_layernorm_indice(node, idx) elif "embedding" == node_name: self._assign_embedding_indice(node, idx) + elif "linear" == node_name: + self._assign_linear_indice(node, idx) elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]): self._assign_elementwise_indice(node, idx) else: From b4566dd29b4a788c98c1cf26c86c8b51bb3a4d32 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 3 Feb 2023 16:05:03 +0800 Subject: [PATCH 17/30] optimize max region search, support conv --- colossalai/autochunk/search_chunk.py | 22 +++++++++++++++++--- colossalai/autochunk/trace_indice.py | 30 +++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 0278e03f78de..a758b3300999 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -75,8 +75,8 @@ def _init_trace(self) -> None: max_chunk_region_list = [] while True: max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx) - cur_node_idx = max_chunk_region[1] - if cur_node_idx == len(active_nodes) - 1: + cur_node_idx = max_chunk_region[1] + 1 + if cur_node_idx >= len(active_nodes) - 1: break max_chunk_region_list.append(max_chunk_region) @@ -135,6 +135,7 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_ min_active_node_num = min(active_node_num[free_var_num:]) threshold = max(free_var_num, min_active_node_num) + # normal search # from peak_node to free_var inside_flag = False chunk_region_start = free_var_num @@ -144,7 +145,6 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_ if inside_flag and active_node_num[i] > threshold: chunk_region_start = i + 1 break - # from peak_node to len-2 inside_flag = False chunk_region_end = len(active_node) - 1 @@ -155,6 +155,22 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_ chunk_region_end = i break + # if normal search fails, use approximate search + if (chunk_region_end - chunk_region_start) > 250: + window_size = 100 + # search min for start + min_num = 1e3 + for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1): + if active_node_num[i] < min_num: + min_num = active_node_num[i] + chunk_region_start = i + # search min for end + min_num = 1e3 + for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1): + if active_node_num[i] < min_num: + min_num = active_node_num[i] + chunk_region_end = i + # avoid chunk regions overlap if chunk_regions is not None: for i in chunk_regions: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 68a19e9162e0..89563226bb7a 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -354,6 +354,32 @@ def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: self._inherit_more_indice_from_node(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) + def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for conv2d op. 
+ + Args: + node (node) + node_idx (int) + """ + # get conv module + node_targets = node.target.split(".") + conv_module = node.graph.owning_module + for i in node_targets: + conv_module = getattr(conv_module, i) + assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented" + + # get conv input + assert len(node.args) == 1 + input_node = node.args[0] + assert len(get_node_shape(input_node)) == 4 + + # assgin index + self._assign_indice_as_input(node, node_idx, input_node) + self._del_dim(node_idx, 1) + self._add_dim(node_idx, 1) + self._mark_computation(node, node_idx, [1]) + def _assign_layernorm_indice(self, node, idx): """ Assign indice for layernorm op. @@ -798,7 +824,9 @@ def trace_indice(self) -> None: self._assign_embedding_indice(node, idx) elif "linear" == node_name: self._assign_linear_indice(node, idx) - elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]): + elif "conv2d" == node_name: + self._assign_conv2d_indice(node, idx) + elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]): self._assign_elementwise_indice(node, idx) else: raise NotImplementedError(node_name, "module not implemented yet!") From b532e293b944c341956e8aeb57577582396176f6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 6 Feb 2023 09:49:40 +0800 Subject: [PATCH 18/30] update unet test --- .../test_autochunk_diffuser/test_autochunk_unet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index 9ebe6f393b20..ef9f5c2b1246 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -18,9 +18,8 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE BATCH_SIZE = 2 -SEQ_LENGTH = 5 -HEIGHT = 224 -WIDTH = 224 +HEIGHT = 448 +WIDTH = 448 IN_CHANNELS = 3 LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) @@ -44,7 +43,7 @@ def get_data(shape: tuple) -> Tuple[List, List]: ) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("shape", [LATENTS_SHAPE]) -@pytest.mark.parametrize("max_memory", [64]) +@pytest.mark.parametrize("max_memory", [None]) def test_evoformer_block(model, shape, max_memory): run_func = partial( run_test, @@ -62,7 +61,7 @@ def test_evoformer_block(model, shape, max_memory): run_test( rank=0, data=get_data(LATENTS_SHAPE), - max_memory=64, + max_memory=None, model=UNet2DModel, print_code=False, print_mem=False, From baf001e87703e3f54470c7718fd46367649c9ef5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 6 Feb 2023 11:04:15 +0800 Subject: [PATCH 19/30] support some op --- colossalai/autochunk/trace_indice.py | 52 +++++++++++++++++++++------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 89563226bb7a..07fd78f36c9e 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -150,7 +150,7 @@ def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None: for i in range(len(node_from_indice)): self._inherit_indice(node_from, i, node_to, i, init=True) - def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None: + def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None: """ inheirt indice from node without init """ @@ 
-327,13 +327,35 @@ def _assign_addmm_indice(self, node: Node, node_idx: int) -> None: node_idx (int) """ bias, input_node, weight = node.args - + assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2 self._assign_indice_as_input(node, node_idx, input_node) self._inherit_indice(weight, 1, node, -1) - self._inherit_indice(bias, -1, node, -1) + self._inherit_more_indice_from_node_with_exclude(bias, node) self._mark_computation(node, node_idx, [-1]) + def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for baddbmm(batch add and batch matmul) op. + add, matmul_left, matmul_right = args + out = add + (matmul_left x matmul_right) + + Args: + node (node) + node_idx (int) + """ + add, matmul_left, matmul_right = node.args + + assert get_node_shape(add) == get_node_shape(node) + assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) + self._assign_indice_as_input(node, node_idx, matmul_left) + # matmul + self._inherit_indice(matmul_right, -1, node, -1) + self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1]) + self._mark_computation(node, node_idx, [-1]) + # add + self._inherit_more_indice_from_node_with_exclude(add, node) + def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for matmul op. @@ -349,9 +371,9 @@ def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) self._assign_indice_as_input(node, node_idx, matmul_left) - self._inherit_indice(matmul_right, -1, node, -1) - self._inherit_more_indice_from_node(matmul_right, node, [-1, -2]) + self._inherit_indice(matmul_right, -1, node, -1) + self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: @@ -378,7 +400,7 @@ def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: self._assign_indice_as_input(node, node_idx, input_node) self._del_dim(node_idx, 1) self._add_dim(node_idx, 1) - self._mark_computation(node, node_idx, [1]) + self._mark_computation(node, node_idx, [1, 2, 3]) def _assign_layernorm_indice(self, node, idx): """ @@ -408,13 +430,13 @@ def _assign_elementwise_indice(self, node, idx): for node_in in node.args: if type(node_in) == type(node): nodes_in.append(node_in) - self._inherit_more_indice_from_node(node_in, node) + self._inherit_more_indice_from_node_with_exclude(node_in, node) def _assgin_no_change_indice(self, node, idx): self._assign_indice_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): - self._inherit_more_indice_from_node(node_in, node) + self._inherit_more_indice_from_node_with_exclude(node_in, node) def _assign_einsum_indice(self, node, idx): """ @@ -506,7 +528,7 @@ def _assign_cat_indice(self, node: Node, node_idx: int) -> None: nodes_in = flat_list(node.args[0]) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: - self._inherit_more_indice_from_node(n, node) + self._inherit_more_indice_from_node_with_exclude(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) self._add_dim(node_idx, cat_dim) @@ -523,7 +545,7 @@ def _assign_sum_indice(self, node: Node, node_idx: int) -> None: self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: - self._inherit_more_indice_from_node(n, node) + 
self._inherit_more_indice_from_node_with_exclude(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) @@ -791,7 +813,7 @@ def trace_indice(self) -> None: self._assign_linear_indice(node, idx) elif "cat" == node_name: self._assign_cat_indice(node, idx) - elif "matmul" == node_name: + elif any(n == node_name for n in ["matmul", "bmm"]): self._assign_matmul_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) @@ -810,7 +832,11 @@ def trace_indice(self) -> None: self._assign_getitem_indice(node, idx) elif "addmm" == node_name: self._assign_addmm_indice(node, idx) - elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor"]): + elif "baddbmm" == node_name: + self._assign_baddbmm_indice(node, idx) + elif "interpolate" == node_name: + continue # TODO + elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]): self._assign_all_indice(node, idx) elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): continue @@ -820,6 +846,8 @@ def trace_indice(self) -> None: node_name = get_module_node_name(node) if "layernorm" == node_name: self._assign_layernorm_indice(node, idx) + elif "groupnorm" == node_name: + self._assign_layernorm_indice(node, idx) # TODO to change elif "embedding" == node_name: self._assign_embedding_indice(node, idx) elif "linear" == node_name: From dba9f781fd52183b4f6d01f417c0d8469bfaf39f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 6 Feb 2023 14:18:56 +0800 Subject: [PATCH 20/30] support groupnorm and interpolate --- colossalai/autochunk/trace_indice.py | 55 +++++++++++++++------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 07fd78f36c9e..1e41073d7da6 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -402,6 +402,22 @@ def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: self._add_dim(node_idx, 1) self._mark_computation(node, node_idx, [1, 2, 3]) + def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for interpolate op. + + Args: + node (node) + node_idx (int) + """ + # get conv input + assert node.kwargs['size'] is None + assert len(get_node_shape(node)) == 4 + + # assgin index + self._assign_indice_as_input(node, node_idx) + self._mark_computation(node, node_idx, [-1, -2]) + def _assign_layernorm_indice(self, node, idx): """ Assign indice for layernorm op. @@ -415,6 +431,18 @@ def _assign_layernorm_indice(self, node, idx): self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [-1]) + def _assign_groupnorm_indice(self, node, idx): + """ + Assign indice for groupnorm op. + + Args: + node (node) + node_idx (int) + """ + assert len(get_node_shape(node)) == 4 + self._assign_indice_as_input(node, idx) + self._mark_computation(node, idx, [-1, -2, -3]) + def _assign_elementwise_indice(self, node, idx): """ Assign indice for element-wise op (eg. relu sigmoid add mul). @@ -549,29 +577,6 @@ def _assign_sum_indice(self, node: Node, node_idx: int) -> None: cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) - def _assign_arange_indice(self, node: Node, node_idx: int) -> None: - """ - Assign indice for arange op. - - Args: - node (node) - node_idx (int) - """ - self._assign_all_indice(node, node_idx) - - def _assign_tensor_indice(self, node: Node, node_idx: int) -> None: - """ - Assign indice for tensor op. 
- - Args: - node (node) - node_idx (int) - """ - if len(get_node_shape(node)) == 0: - return - else: - raise NotImplementedError() - def _assign_embedding_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for embedding op. @@ -835,7 +840,7 @@ def trace_indice(self) -> None: elif "baddbmm" == node_name: self._assign_baddbmm_indice(node, idx) elif "interpolate" == node_name: - continue # TODO + self._assign_interpolate_indice(node, idx) elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]): self._assign_all_indice(node, idx) elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): @@ -847,7 +852,7 @@ def trace_indice(self) -> None: if "layernorm" == node_name: self._assign_layernorm_indice(node, idx) elif "groupnorm" == node_name: - self._assign_layernorm_indice(node, idx) # TODO to change + self._assign_groupnorm_indice(node, idx) elif "embedding" == node_name: self._assign_embedding_indice(node, idx) elif "linear" == node_name: From ed59541612243cc010ed8740d26591db9a963360 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 6 Feb 2023 15:16:36 +0800 Subject: [PATCH 21/30] update flow search --- colossalai/autochunk/trace_flow.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 11dbb266d4b4..853f69d5bfc5 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -100,6 +100,16 @@ def _assgin_single_node_flow( if not (start_idx <= arg_idx < end_idx): return True + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + if arg_node in all_node_info: + arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)) + # find arg dim if cur_node_dim is not None: # dim is computed @@ -109,6 +119,9 @@ def _assgin_single_node_flow( arg_dim = None else: arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + # chunk dim cannot be in fix dims + if arg_dim in arg_fix_dim: + return False # chunk dim should be None if shape size is 1 if get_node_shape(arg_node)[arg_dim] == 1: arg_dim = None @@ -120,19 +133,11 @@ def _assgin_single_node_flow( else: arg_dim = None - # get fix dim - arg_fix_dim = [] - if cur_node_dim is not None: - for i in cur_node_fix_dim: - fix_dim_source = cur_node_source[i] - if arg_idx in fix_dim_source: - arg_fix_dim.append(fix_dim_source[arg_idx][0]) - # if already in node_info, arg dim must be same if arg_node in all_node_info: if all_node_info[arg_node]["chunk_dim"] != arg_dim: return False - all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)) + all_node_info[arg_node]["fix_dim"] = arg_fix_dim # else add it to list else: all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} @@ -164,6 +169,8 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): continue if is_non_compute_node(arg): continue + if get_node_shape(arg) is None: + continue arg_list.append(arg) flow_flag = self._assgin_single_node_flow( arg, From 4b10d383742a5588e845a6048b468b88652f3b17 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 6 Feb 2023 16:05:34 +0800 Subject: [PATCH 22/30] add fix dim in node flow --- colossalai/autochunk/trace_flow.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git 
From 4b10d383742a5588e845a6048b468b88652f3b17 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 16:05:34 +0800
Subject: [PATCH 22/30] add fix dim in node flow

---
 colossalai/autochunk/trace_flow.py | 28 +++++-----------------------
 1 file changed, 5 insertions(+), 23 deletions(-)

diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index 853f69d5bfc5..16815215f52b 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -133,6 +133,11 @@ def _assgin_single_node_flow(
         else:
             arg_dim = None
 
+        # add arg rest dim as fix dim
+        arg_fix_dim = list(range(len(get_node_shape(arg_node))))
+        if arg_dim is not None:
+            arg_fix_dim.remove(arg_dim)
+
         # if already in node_info, arg dim must be same
         if arg_node in all_node_info:
             if all_node_info[arg_node]["chunk_dim"] != arg_dim:
@@ -187,29 +192,6 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
                 if flow_flag == False:
                     return None
 
-            if len(arg_list) >= 2:
-                # need to mark fix dim
-                if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
-                    for arg in arg_list:
-                        if get_node_shape(arg) is None:
-                            continue
-                        if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
-                            continue
-                        arg_chunk_dim = all_node_info[arg]["chunk_dim"]
-                        arg_fix_dim = all_node_info[arg]["fix_dim"]
-                        arg_shape = get_node_shape(arg)
-                        # add all dim as fix dim except chunk dim
-                        for i, shape in enumerate(arg_shape):
-                            if shape != 1 and i != cur_node_chunk_dim:
-                                if i == arg_chunk_dim:
-                                    return None
-                                if i not in arg_fix_dim:
-                                    arg_fix_dim.append(i)
-                elif any(i == get_node_name(cur_node)
-                         for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
-                    pass
-                else:
-                    raise NotImplementedError()
 
             cur_node_list = next_node_list
         return all_node_info

From f5a370f1bf282be018c470cbb68a5339b9ec663b Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 16:08:34 +0800
Subject: [PATCH 23/30] fix utils

---
 .../test_autochunk_diffuser/test_autochunk_diffuser_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
index ee228f842aed..d28e96058b28 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
@@ -62,7 +62,7 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None;  chunk_size = None;" in code
+    assert "chunk_size = None;  " in code
 
     # assert result
     inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
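The corrected assert keys on the `chunk_size = None;` line that AutoChunkCodeGen plants in its emitted source. For orientation, here is a hand-written loop computing what an emitted chunk region computes; this is illustrative only, the generated text itself looks different:

import torch

def chunked_softmax(x: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    # analogue of an emitted chunk loop: slice dim 0, run the region body
    # per slice, write each result into a preallocated output buffer
    out = torch.empty_like(x)
    for chunk_idx in range(0, x.shape[0], chunk_size):
        step = min(chunk_size, x.shape[0] - chunk_idx)
        out[chunk_idx:chunk_idx + step] = torch.softmax(x[chunk_idx:chunk_idx + step], dim=-1)
    return out

x = torch.randn(200, 128)
assert torch.allclose(chunked_softmax(x), torch.softmax(x, dim=-1))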
From b07c0b0478416e2032b9936f38315b8b4871ff76 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 16:18:58 +0800
Subject: [PATCH 24/30] rename

---
 colossalai/autochunk/autochunk_codegen.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 82937db9f6ba..8701ab641258 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -143,7 +143,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     return context
 
 
-def _replace_ones_like(
+def _replace_new_tensor_like_shape(
     search_chunk: SearchChunk,
    chunk_infos: List[Dict],
     region_idx: int,
@@ -265,8 +265,8 @@ def emit_code_with_chunk(
             body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
             # replace output var with chunk var
             body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
-            # ones like
-            body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # tensor like
+            body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = "    " + body[-1]

From 95cf822f9bcacf63c4201cb413e77505b6477ab6 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 16:58:26 +0800
Subject: [PATCH 25/30] support diffusion

---
 colossalai/autochunk/autochunk_codegen.py | 33 +++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 8701ab641258..3432ca9afba4 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -154,7 +154,7 @@ def _replace_new_tensor_like_shape(
     """
     add chunk slice for new tensor op such as ones like
     """
-    if "ones_like" in node.name:
+    if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
         meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
         chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
         if get_node_shape(meta_node)[chunk_dim] != 1:
@@ -166,6 +166,33 @@ def _replace_new_tensor_like_shape(
     return body
 
 
+def _replace_new_tensor_shape(
+    search_chunk: SearchChunk,
+    chunk_infos: List[Dict],
+    region_idx: int,
+    node_idx: int,
+    node: Node,
+    body: List[str],
+) -> List[str]:
+    """
+    add chunk slice for new tensor op such as ones
+    """
+    if get_node_name(node) in ["ones", "zeros", "empty"]:
+        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
+        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
+        if chunk_dim is None:
+            return body
+        if get_node_shape(meta_node)[chunk_dim] == 1:
+            return body
+        origin_shape = str(node.args)
+        new_shape = list(node.args)
+        new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
+        new_shape = str(new_shape)
+        new_shape = new_shape.replace("'", "")
+        body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
+    return body
+
+
 def _add_node_slice(
     chunk_nodes: List[Node],
     region_idx: int,
@@ -265,8 +292,10 @@ def emit_code_with_chunk(
             body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
             # replace output var with chunk var
             body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
-            # tensor like
+            # new tensor like
             body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # new tensor
+            body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = "    " + body[-1]
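`_replace_new_tensor_shape` operates on the generated source as text: it rebuilds the shape argument of a `ones`/`zeros`/`empty` call so the chunk dimension is allocated per loop iteration. A standalone demo of the same string surgery, using a made-up node shape and chunk dim:

# Made-up node: torch.ones(2, 64, 64) chunked along dim 1 (full size 64).
origin_args = (2, 64, 64)
chunk_dim, full_size = 1, 64

origin = str(origin_args)[1:-1]                       # "2, 64, 64"
new_shape = list(origin_args)
new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % full_size
new = str(new_shape).replace("'", "")[1:-1]

line = "ones = torch.ones(2, 64, 64, device=device)"
print(line.replace(origin, new))
# ones = torch.ones(2, min(chunk_size, 64 - chunk_idx), 64, device=device)

Note that the two early exits return `body` unchanged rather than a bare `return`: the caller assigns the result back to `body`, so returning None there would crash the next codegen step.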
From c6529adcfc2b307726824db0e14fc8712abe454b Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 17:28:56 +0800
Subject: [PATCH 26/30] update diffuser

---
 .../test_autochunk_diffuser_utils.py          | 28 ++++++++++++++-----
 .../test_autochunk_unet.py                    |  4 +--
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
index d28e96058b28..529250fe8f51 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
@@ -22,6 +22,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:
@@ -42,7 +43,7 @@ def assert_codegen_run(
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos
@@ -70,10 +71,21 @@ def assert_codegen_run(
     model.cuda().eval()
     gm.eval()
     with torch.no_grad():
-        out_gm = gm(*inputs)
-        out_model = model(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+        out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+            print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
     assert torch.allclose(out_gm["sample"], out_model["sample"],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                          atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                               torch.abs(out_gm["sample"] - out_model["sample"]))
 
     return chunks
@@ -84,9 +96,10 @@ def run_test(
     model: Any,
     data: tuple,
     max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     # launch colossalai
@@ -108,6 +121,7 @@ def run_test(
         max_memory=max_memory,
         print_code=print_code,
         print_mem=print_mem,
+        print_est_mem=print_est_mem,
         print_progress=print_progress,
     )
 

diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index ef9f5c2b1246..7d1aab2b19be 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -50,9 +50,6 @@ def test_evoformer_block(model, shape, max_memory):
         max_memory=max_memory,
         model=model,
         data=get_data(shape),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
 
@@ -65,5 +62,6 @@ def test_evoformer_block(model, shape, max_memory):
         model=UNet2DModel,
         print_code=False,
         print_mem=False,
+        print_est_mem=False,
         print_progress=False,
     )
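The updated test now distinguishes measured memory (`print_mem`, real CUDA peak stats) from the codegen's estimate (`print_est_mem`), and it clones every tensor input before each forward pass, presumably so a model that mutates inputs in place cannot contaminate the baseline run. The cloning guard in isolation (hypothetical helper name):

import torch

def run_with_fresh_inputs(fn, inputs):
    # clone tensors so each run sees pristine data; non-tensors pass through
    fresh = [i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]
    return fn(*fresh)

x = torch.randn(4, 8)
out = run_with_fresh_inputs(lambda t: t.mul_(2), [x])
assert torch.equal(out, x * 2)    # the in-place mul_ hit the clone, not x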
From b2c4fdde9af113cd8d2f42de6c11881bb1fb9352 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 17:51:27 +0800
Subject: [PATCH 27/30] update chunk search

---
 colossalai/autochunk/search_chunk.py | 21 +++------------------
 1 file changed, 3 insertions(+), 18 deletions(-)

diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index a758b3300999..9a0d1120bdc7 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -287,12 +287,6 @@ def _step_search(
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
 
-    def _stop_search(self, init_mem_peak, mem_peak):
-        sorted_init_mem_peak = sorted(init_mem_peak)
-        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
-            return True
-        return False
-
     def search_region(self) -> Dict:
         """
         Search all chunk regions:
@@ -307,11 +301,7 @@ def search_region(self) -> Dict:
         get_logger().info("AutoChunk start searching chunk regions")
 
         chunk_infos = []
-        (
-            init_mem_peak,
-            _,
-            active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
+        init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
         mem_peak = init_mem_peak
 
         while True:
@@ -320,18 +310,13 @@ def search_region(self) -> Dict:
                 break
             chunk_infos.append(chunk_info)
 
-            (
-                mem_peak,
-                _,
-                active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
+            mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
+                self.node_mgr.get_node_list(), chunk_infos)
 
             if self.print_progress:
                 get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
                                   (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
 
-            if self._stop_search(init_mem_peak, mem_peak):
-                break
         if self.print_mem:
             self.print_mem = False
             self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),

From 0f9b0f5c3d6c2d25fed5026fd666c8372a315fa5 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 17:52:53 +0800
Subject: [PATCH 28/30] optimize imports

---
 colossalai/autochunk/search_chunk.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 9a0d1120bdc7..eb99490957aa 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,14 +8,7 @@
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    NodeMgr,
-    find_chunk_compute_input_and_output_nodes,
-    get_logger,
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
 
 
 class SearchChunk(object):

From 46d4b86003efb6fcbe4907b5a4b5bf6181018783 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 6 Feb 2023 17:53:55 +0800
Subject: [PATCH 29/30] import

---
 colossalai/autochunk/autochunk_codegen.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 3432ca9afba4..90bde8730052 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -9,18 +9,7 @@
 AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
 
 if AUTOCHUNK_AVAILABLE:
-    from torch.fx.graph import (
-        CodeGen,
-        PythonCode,
-        _custom_builtins,
-        _CustomBuiltin,
-        _format_target,
-        _is_from_torch,
-        _Namespace,
-        _origin_type_map,
-        inplace_methods,
-        magic_methods,
-    )
+    from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
 
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
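After patch 27 drops the early-stop heuristic, region search is a plain greedy loop: estimate memory, ask the step search for the best remaining region, re-estimate, and stop only when nothing is found. A skeleton with stand-in callables (not the real SearchChunk interface):

def search_regions(step_search, estimate_mem, node_list):
    # greedy loop: keep chunking the current best region until none helps
    chunk_infos = []
    mem_peak, _, active_nodes = estimate_mem(node_list)
    while True:
        chunk_info = step_search(mem_peak, active_nodes, chunk_infos)
        if chunk_info is None:
            break    # no further chunkable region found
        chunk_infos.append(chunk_info)
        mem_peak, _, active_nodes = estimate_mem(node_list, chunk_infos)
    return chunk_infos

# toy stand-ins: two regions are found, then the search is exhausted
rounds = iter([{"region": (0, 5)}, {"region": (6, 9)}, None])
fake_step = lambda *args: next(rounds)
fake_estimate = lambda nodes, infos=(): ([100.0 / (len(infos) + 1)], None, None)
print(search_regions(fake_step, fake_estimate, []))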
From 24bdbf9bc1bf2b75483476ef0f9d7b1556797d10 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 7 Feb 2023 15:12:55 +0800
Subject: [PATCH 30/30] finish autochunk

---
 .../test_autochunk_diffuser/test_autochunk_unet.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index 7d1aab2b19be..518c7f45124d 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -17,7 +17,7 @@
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 
-BATCH_SIZE = 2
+BATCH_SIZE = 1
 HEIGHT = 448
 WIDTH = 448
 IN_CHANNELS = 3
@@ -33,10 +33,6 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
-@pytest.mark.skipif(
-    True,
-    reason="not implemented",
-)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
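Dropping the unconditional `skipif(True, ...)` mark is what actually activates the UNet test; the version/repo guard still applies, because pytest skips a test when any of its stacked skipif conditions is true. A tiny reproduction of that stacking behavior:

import pytest

@pytest.mark.skipif(False, reason="this mark alone would not skip")
@pytest.mark.skipif(1 + 1 == 3, reason="nor would this one")
def test_stacked_skipif_runs():
    # runs because every skipif condition above is False
    assert True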