From 6f39d4c0246067ce26703f5bdb708560f6a54a36 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 17:04:50 +0800 Subject: [PATCH 1/3] [autoparallel] embedding metainfo --- .../meta_profiler/meta_registry/__init__.py | 1 + .../meta_profiler/meta_registry/activation.py | 2 +- .../meta_profiler/meta_registry/embedding.py | 52 +++++++++++++ .../test_metainfo/test_activation_metainfo.py | 2 +- .../test_metainfo/test_embedding_metainfo.py | 77 +++++++++++++++++++ 5 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py index aa5f77f6591e..359590c1fc04 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py @@ -1,6 +1,7 @@ from .activation import * from .binary_elementwise_ops import * from .conv import * +from .embedding import * from .linear import * from .norm import * from .pooling import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index c659cd9ac389..b7f7ece75586 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -8,7 +8,7 @@ from ..registry import meta_register -__all__ = ["relu_meta_info"] +__all__ = ["relu_meta_info", "softmax_meta_info"] @meta_register.register(torch.nn.ReLU) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py new file mode 100644 index 000000000000..2997f31adff8 --- /dev/null +++ 
b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -0,0 +1,52 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..registry import meta_register + +__all__ = ["embedding_meta_info"] + + +@meta_register.register(torch.nn.Embedding) +def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """torch.nn.Embedding metainfo generator + +    Returns: +        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: compute cost, memory cost, forward inputs, forward buffers and forward outputs + """ + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + weight_tensor = next(filter(lambda x: x.type == OperationDataType.PARAM, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) + bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], + [weight_tensor]) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will + # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume + # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), +                                 parameter=0, +                                 temp=0, +                                 buffer=0) + 
bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0) + + total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor)] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor)] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index b9b42f8c161d..80636e935336 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -104,5 +104,5 @@ def test_sofmax_meta_info(): if __name__ == '__main__': - # test_ReLU_meta_concrete_info_match() + test_ReLU_meta_concrete_info_match() test_sofmax_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py new file mode 100644 index 000000000000..2fb1306546ca --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -0,0 +1,77 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize 
import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +def test_embedding_meta_info(): + meta_func = meta_register.get(torch.nn.Embedding) + + # construct meta tensors + input_tensor = torch.randint(0, 50256, (8, 1024), device="meta") + weight_tensor = torch.rand(50257, 1024, device="meta") + output_tensor = torch.rand(8, 1024, 1024, device="meta") + + # construct operation data + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + + weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor) + + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + + # construct args and kwargs + args = [input_data, weight_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.randint(0, 50256, (8, 1024), device="cuda") + embedding_module = torch.nn.Embedding(50257, 1024).cuda() + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = embedding_module(input_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + 
torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, + bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_embedding_meta_info() From 9be91cbda1bf37817a0bdbc130e2461f17861a45 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 21:32:51 +0800 Subject: [PATCH 2/3] [autoparallel] fix function name in test_activation_metainfo --- .../test_metainfo/test_activation_metainfo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index 80636e935336..b124817dbea5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -def _ReLU_module_mem_test(rank, world_size, port): +def _relu_module_mem_test(rank, world_size, port): """This function is for ReLU memory test Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL @@ -53,9 +53,9 @@ def _ReLU_module_mem_test(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() -def test_ReLU_meta_concrete_info_match(): +def test_relu_meta_concrete_info_match(): world_size = 4 - run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port()) + run_func_module = partial(_relu_module_mem_test, world_size=world_size, port=free_port()) mp.spawn(run_func_module, 
nprocs=world_size) @@ -104,5 +104,5 @@ def test_sofmax_meta_info(): if __name__ == '__main__': - test_ReLU_meta_concrete_info_match() + test_relu_meta_concrete_info_match() test_sofmax_meta_info() From dcd36b3e1328521fc2d100612bfc992b213719f5 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 22:05:36 +0800 Subject: [PATCH 3/3] [autoparallel] undo changes in activation metainfo and related tests --- .../meta_profiler/meta_registry/activation.py | 2 +- .../test_metainfo/test_activation_metainfo.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index b7f7ece75586..c659cd9ac389 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -8,7 +8,7 @@ from ..registry import meta_register -__all__ = ["relu_meta_info", "softmax_meta_info"] +__all__ = ["relu_meta_info"] @meta_register.register(torch.nn.ReLU) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index b124817dbea5..b9b42f8c161d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -def _relu_module_mem_test(rank, world_size, port): +def _ReLU_module_mem_test(rank, world_size, port): """This function is for ReLU memory test Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL @@ -53,9 +53,9 @@ def _relu_module_mem_test(rank, world_size, port): 
@run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() -def test_relu_meta_concrete_info_match(): +def test_ReLU_meta_concrete_info_match(): world_size = 4 - run_func_module = partial(_relu_module_mem_test, world_size=world_size, port=free_port()) + run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port()) mp.spawn(run_func_module, nprocs=world_size) @@ -104,5 +104,5 @@ def test_sofmax_meta_info(): if __name__ == '__main__': - test_relu_meta_concrete_info_match() + # test_ReLU_meta_concrete_info_match() test_sofmax_meta_info()