From 29a61d2a00094d0b6be15ba782c5a02a53abd1bd Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Sun, 19 Feb 2023 16:02:16 +0800
Subject: [PATCH 1/2] [autoparallel] non spmd meta information generator

---
 .../meta_profiler/meta_registry/non_spmd.py   | 34 ++++++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
new file mode 100644
index 000000000000..4634d3ccdcfd
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -0,0 +1,34 @@
+import operator
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["non_spmd_meta_info"]
+
+
+@meta_register.register(torch.Size)
+@meta_register.register(torch.Tensor.size)
+@meta_register.register(torch.finfo)
+@meta_register.register(operator.le)
+def non_spmd_meta_info(
+    *args, **kwargs
+) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+    """Non-SPMD node meta information generator.
+
+    These nodes are not handled by the SPMD solver, so placeholder meta information is returned for them:
+    zero compute cost, default-constructed memory costs and empty tensor lists.
+
+    Returns:
+        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+            compute cost, memory cost, forward inputs, forward buffers and forward outputs
+    """
+    compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+    memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())
+    fwd_in, fwd_buffer, fwd_out = [], [], []
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

From 88453536fe06a276ef67371833e8f970e5e86a4d Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Mon, 20 Feb 2023 19:28:51 +0800
Subject: [PATCH 2/2] [autoparallel] patch meta information for non spmd nodes

---
 colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
index 359590c1fc04..9a1b15a4e463 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
@@ -3,5 +3,6 @@
 from .conv import *
 from .embedding import *
 from .linear import *
+from .non_spmd import *
 from .norm import *
 from .pooling import *