Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""

input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

# construct forward args for flop mapping
fwd_in_args = [input_op_data.data, other_op_data.data]
fwd_in_args = [opdata.data for opdata in input_op_data]
fwd_out_args = [output_op_data.data]

# calculate cost

# calculate compute cost
# NOTE: we set bwd_compute_cost to twice fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

# calculate memory cost
param_mem_cost = activation_size(
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size([input_op_data.data, output_op_data.data]),
activation=activation_size(output_op_data.data),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
Expand Down
46 changes: 27 additions & 19 deletions colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
Expand Down Expand Up @@ -234,15 +234,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)

return self.strategies_vector

Expand Down Expand Up @@ -281,14 +285,18 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)

return self.strategies_vector