From 9cabb7207173ba9c930adc441fb7a41e4f3d17ba Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 10:26:19 +0800 Subject: [PATCH 1/7] [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline --- .../auto_parallel/meta_profiler/constants.py | 5 +- .../meta_registry/binary_elementwise_ops.py | 2 +- .../auto_parallel/meta_profiler/metainfo.py | 4 +- colossalai/auto_parallel/passes/constants.py | 8 ++ .../auto_parallel/passes/meta_info_prop.py | 75 +++++++++---------- 5 files changed, 51 insertions(+), 43 deletions(-) create mode 100644 colossalai/auto_parallel/passes/constants.py diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py index 714674b7b425..35b8c13ee8ff 100644 --- a/colossalai/auto_parallel/meta_profiler/constants.py +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -5,8 +5,11 @@ from ..tensor_shard.constants import * -# list of inplace operations +# list of inplace module INPLACE_MODULE = [nn.ReLU] +# list of inplace operations +INPLACE_OPS = [torch.flatten] + # list of operations that do not save forward activations NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index b4cc58d05c0c..15c3063b759b 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')] + fwd_in = [] 
fwd_buffer = [] fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py index ff76e3059fef..218187768a7b 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -12,7 +12,7 @@ ) from colossalai.tensor.sharding_spec import ShardingSpec -from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION +from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register __all__ = ['MetaInfo'] @@ -104,6 +104,8 @@ def compute_metainfo(self): # construct kwargs if self.target in INPLACE_MODULE: kwargs = {'inplace': self.target.inplace} + elif self.target in INPLACE_OPS: + kwargs = {'inplace': True} else: kwargs = {'inplace': False} diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py new file mode 100644 index 000000000000..b86088474644 --- /dev/null +++ b/colossalai/auto_parallel/passes/constants.py @@ -0,0 +1,8 @@ +import torch + +OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 607f7e17ec73..bdeaeffedc66 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -8,9 +8,9 @@ from torch.fx.node import Node from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS from colossalai.fx._compatibility import compatibility from colossalai.fx.profiler import GraphInfo -from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS def _normalize_tuple(x): @@ -46,7 +46,7 @@ def _is_inplace(self, node: Node): """ 
Check if the node is inplace operation. """ - if node.op == 'call_method': + if node.op == 'call_module': return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD elif node.op == "call_function": return node.target in OUTPUT_SAVED_OPS @@ -102,56 +102,51 @@ def node_handler(self, node: Node) -> None: meta_info: MetaInfo # set data_ptr for input_tensor in MetaInfo class - input_tensor: List[torch.Tensor] = meta_info.fwd_in - buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer - output_tensor: List[torch.Tensor] = meta_info.fwd_out + input_tensors: List[torch.Tensor] = meta_info.fwd_in + buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer + output_tensors: List[torch.Tensor] = meta_info.fwd_out - if len(input_tensor) > 0: + if self._is_inplace(node): + # inplace operation will not create new tensor, and it only has one parent node + # TODO: Verify this observation + # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node + parent_node = list(node._input_nodes.keys())[0] + parent_tensor = parent_node.meta.get("fwd_out")[0] + parent_tensor: torch.Tensor + for tensor in input_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in buffer_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in output_tensors: + tensor.data_ptr = parent_tensor.data_ptr + + else: for par in node._input_nodes: - if par.meta: - if len(par.meta["fwd_out"]) > 0: - # set data_ptr for the input_tensor of current node from the output_tensor of its parent node - for tensor in par.meta["fwd_out"]: - tensor: torch.Tensor - target_tensor = next( - (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None) - target_tensor.data_ptr = tensor.data_ptr + # set data_ptr for the input_tensor of current node from the output_tensor of its parent node + for tensor in par.meta.get("fwd_out", []): + tensor: torch.Tensor + target_input_tensor = next( + (x for x in input_tensors if not x.data_ptr() and 
x.shape == tensor.shape), None) + if target_input_tensor is not None: + target_input_tensor.data_ptr = tensor.data_ptr # set data_ptr for tensor in input_tensor that is not set - for tensor in input_tensor: + for tensor in input_tensors: if not tensor.data_ptr(): self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_in = input_tensor - - if self._is_inplace(node): - # inplace operation will not create new tensor - # set data_ptr for buffer_tensor and output_tensor of current node - for tensor in input_tensor: - tensor: torch.Tensor - target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape), - None) - target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape), - None) - target_buffer_tensor.data_ptr = tensor.data_ptr - target_output_tensor.data_ptr = tensor.data_ptr - # attach them to graph_info - graph_info.fwd_tmp = buffer_tensor - graph_info.fwd_out = output_tensor - - else: # set data_ptr for buffer_tensor - for tensor in buffer_tensor: + for tensor in buffer_tensors: self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_tmp = buffer_tensor # set data_ptr for output_tensor - for tensor in output_tensor: + for tensor in output_tensors: self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_out = output_tensor + + # attach them to graph_info + graph_info.fwd_in = input_tensors + graph_info.fwd_tmp = buffer_tensors + graph_info.fwd_out = output_tensors # fetch other memory informations memory_cost = meta_info.memory_cost From ad5ab65a1d24dfba904b016da5a052947a0786ba Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 10:39:34 +0800 Subject: [PATCH 2/7] [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop --- colossalai/fx/profiler/shard_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/profiler/shard_utils.py 
b/colossalai/fx/profiler/shard_utils.py index a765e5055b28..34feefb4336a 100644 --- a/colossalai/fx/profiler/shard_utils.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float: fwd_time (float): the result of `fwd_time` """ # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs - return n.meta["fwd_flop"] + return n.meta["fwd_time"] def calculate_bwd_time(n: Node) -> float: @@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float: bwd_time (float): the result of `bwd_time` """ # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs - return n.meta["bwd_flop"] + return n.meta["bwd_time"] From f29bbba5122f28e3e24bf847e29143ef728f1fde Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 11:35:01 +0800 Subject: [PATCH 3/7] [autoparallel] specify comm nodes' memory cost in construct chain --- .../auto_parallel/checkpoint/ckpt_solver_rotor.py | 11 +++++++++-- colossalai/auto_parallel/passes/meta_info_prop.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 6ef53c9d1380..cd5b70d110dc 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -4,6 +4,7 @@ from torch import Tensor from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions from colossalai.fx.profiler import ( activation_size, @@ -131,8 +132,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: fwd_mem_peak = 0 for n in node: assert isinstance(n, Node), f'{n} is not a Node' - xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) - fwd_mem_peak = max(fwd_mem_peak, xbar + 
n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + if n.target == runtime_apply or n.target == runtime_comm_spec_apply: + # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta + xbar += n.meta['fwd_mem_out'] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + else: + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) btime += max(calculate_bwd_time(n), 1.0) diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index bdeaeffedc66..f7e07ef1ec18 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -151,6 +151,7 @@ def node_handler(self, node: Node) -> None: # fetch other memory informations memory_cost = meta_info.memory_cost graph_info.fwd_mem_tmp = memory_cost.fwd.temp + graph_info.fwd_mem_out = memory_cost.fwd.activation graph_info.bwd_mem_tmp = memory_cost.bwd.temp graph_info.bwd_mem_out = memory_cost.bwd.activation From f6c0b28dff27cdc4d4262b7e9572c4106dbc13e3 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 16:59:38 +0800 Subject: [PATCH 4/7] [autoparallel] fix wrong runtime apply calculation --- colossalai/tensor/shape_consistency.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index daf81034f384..e3f8fbdb55b2 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -441,6 +441,8 @@ def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, if discard_input: alloc_numel -= input_numel + return alloc_numel, peak_numel + def split_analysis(comm_spec: CommSpec, discard_input: bool, 
alloc_numel: int, peak_numel: int): """analyze split memory footprint split will allocate memory for the output tensor if we don't apply shard on the first dimension of @@ -478,11 +480,13 @@ def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, p # kind of weird, and I think we could ignore it for now. pass + return alloc_numel, peak_numel + def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory """ - pass + return alloc_numel, peak_numel def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze all_to_all memory footprint @@ -508,11 +512,13 @@ def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, if discard_input: alloc_numel -= input_numel + return alloc_numel, peak_numel + def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory """ - pass + return alloc_numel, peak_numel pattern_to_func_dict = { CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], @@ -540,9 +546,9 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair if idx == 0: - fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) else: - fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 From 9d58c96d1d5b7fe9ce778380f95239812da87b16 Mon Sep 17 00:00:00 2001 
From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 17:00:43 +0800 Subject: [PATCH 5/7] [autoparallel] fix wrong runtime apply calculation --- colossalai/tensor/shape_consistency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index e3f8fbdb55b2..6118ee9a612e 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -555,7 +555,7 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) From 5a3628dca0fa4bba7b6dcaf762e01d99fefa0c00 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 17:14:12 +0800 Subject: [PATCH 6/7] [autoparallel] fix wrong runtime apply calculation --- colossalai/tensor/shape_consistency.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 6118ee9a612e..2831b10a3c57 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -545,17 +545,18 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair - if idx == 0: - fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) - 
else: - fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, + fwd_peak_numel) if idx == 0 else fwd_action( + comm_spec, True, fwd_alloc_numel, fwd_peak_numel) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, + bwd_peak_numel) if idx == 0 else bwd_action( + comm_spec, True, bwd_alloc_numel, bwd_peak_numel) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) From 259bee11be223ad6b88107263d92d05b0fcfc8d0 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 3 Jan 2023 20:22:51 +0800 Subject: [PATCH 7/7] [autoparallel] bypass metainfo when available and modify BCAST_FUNC_OP --- .../meta_registry/binary_elementwise_ops.py | 11 ++--- .../tensor_shard/node_handler/node_handler.py | 46 +++++++++++-------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index 15c3063b759b..281a92c0d4f1 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs """ - 
input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] + input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args)) # construct forward args for flop mapping - fwd_in_args = [input_op_data.data, other_op_data.data] + fwd_in_args = [opdata.data for opdata in input_op_data] fwd_out_args = [output_op_data.data] # calculate cost # calculate compute cost # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case - fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args) bwd_compute_cost = fwd_compute_cost * 2 compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost - param_mem_cost = activation_size( - [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM]) + param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM]) fwd_mem_cost = MemoryCost( - activation=activation_size([input_op_data.data, output_op_data.data]), + activation=activation_size(output_op_data.data), parameter=param_mem_cost, ) bwd_mem_cost = MemoryCost( diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index af3cb5810d11..78dc58c905ec 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -4,7 +4,7 @@ import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo +from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register from 
colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -234,15 +234,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() - metainfo_vector = [] - for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) - strategy.compute_cost = metainfo.compute_cost - strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) - - # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) return self.strategies_vector @@ -281,14 +285,18 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() - metainfo_vector = [] - for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) - strategy.compute_cost = metainfo.compute_cost - strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) - - # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. 
+ # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) return self.strategies_vector