From a7a254db32769c423fbfb58918c59c4c1e5d146e Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 22:02:56 +0800 Subject: [PATCH 1/3] [autoparallel] tanh meta information --- .../meta_profiler/meta_registry/activation.py | 294 +++++++++++------- .../test_metainfo/test_activation_metainfo.py | 56 +--- 2 files changed, 191 insertions(+), 159 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index c659cd9ac389..457478144ba2 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -1,124 +1,194 @@ -from typing import List, Tuple +from typing import Callable, List, Tuple import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.fx.profiler.opcount import elementwise_flop_counter from ..registry import meta_register -__all__ = ["relu_meta_info"] - - -@meta_register.register(torch.nn.ReLU) -def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: - """torch.nn.ReLU metainfo generator - The aten graph of torch.nn.ReLU is - graph(): - %input_2 : [#users=1] = placeholder[target=placeholder](default=) - %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {}) - %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) - %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {}) - %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {}) - %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {}) - %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) +# __all__ = ["relu_meta_info"] +__all__ = ["elementwise_meta_info"] + +# @meta_register.register(torch.nn.ReLU) +# def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: +# """torch.nn.ReLU metainfo generator +# The aten graph of torch.nn.ReLU is +# graph(): +# %input_2 : [#users=1] = placeholder[target=placeholder](default=) +# %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {}) +# %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) +# %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {}) +# %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {}) +# %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {}) +# %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + +# Returns: +# Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs +# """ + +# input_tensor = args[0].data +# output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data +# is_inplace = kwargs.get("inplace", False) + +# # construct input args for forward +# fwd_in_args = [input_tensor] + +# # construct input args for backward +# bwd_in_args = [output_tensor] + +# # calculate cost +# # the fwd op with compute cost is relu.default +# # the bwd op with compute cost is threshold_backward + +# # calculate compute cost +# fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,)) +# bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,)) +# compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + +# # calculate memory cost +# # NOTE: the inplace ReLU don't have forward memory cost +# # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward +# fwd_memory_cost = MemoryCost( +# activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]), +# parameter=0, +# temp=0, +# buffer=0) + +# bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0) + +# # total cost is the sum of forward and backward cost +# total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, +# parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + +# memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + +# # store fwd_in, fwd_buffer, fwd_out +# # NOTE: It might seems a little bit weird here, we just want to align it with the older version +# # of MetaInfoProp. In the future we might modify this part to make it clearer. +# fwd_in = [] +# fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] +# fwd_out = [torch.zeros_like(output_tensor, device='meta')] + +# return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + +# @meta_register.register(torch.nn.Softmax) +# @meta_register.register(torch.nn.functional.softmax) +# def softmax_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: +# """torch.nn.Softmax metainfo generator +# Returns: +# Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs +# """ +# input_tensor = next( +# filter( +# lambda x: +# (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', +# args)).data +# output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data +# softmax_dim = next(filter(lambda x: x.name == 'softmax_dim', args)).data + +# # calculate cost + +# # calculate compute cost +# fwd_compute_cost = flop_mapping[torch.ops.aten._softmax.default]([input_tensor], [output_tensor]) +# bwd_compute_cost = flop_mapping[torch.ops.aten._softmax_backward_data.default]([output_tensor], [input_tensor]) + +# compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + +# # calculate memory cost +# # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward +# fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), +# parameter=0, +# temp=0, +# buffer=0) +# bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), +# parameter=0, +# temp=activation_size(input_tensor), +# buffer=0) + +# # total cost is the sum of forward and backward cost +# total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, +# parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, +# temp=fwd_memory_cost.temp + bwd_memory_cost.temp, +# buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + +# memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + +# # store fwd_in, fwd_buffer, fwd_out +# fwd_in = [] +# fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] +# fwd_out = [torch.zeros_like(output_tensor, device='meta')] + +# return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + +def elementwise_meta_info(temp_mem_scale: float = 0) -> Callable: + """This is a function to create the meta information generator for elementwise operations + + Args: + temp_mem_scale (float, optional): temp memory scaling factor. Defaults to 0. Returns: - Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + Callable: meta information generator """ - input_tensor = args[0].data - output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - is_inplace = kwargs.get("inplace", False) - - # construct input args for forward - fwd_in_args = [input_tensor] - - # construct input args for backward - bwd_in_args = [output_tensor] - - # calculate cost - # the fwd op with compute cost is relu.default - # the bwd op with compute cost is threshold_backward - - # calculate compute cost - fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,)) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) - - # calculate memory cost - # NOTE: the inplace ReLU don't have forward memory cost - # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost( - activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]), - parameter=0, - temp=0, - buffer=0) - - bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0) - - # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) - - memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - - # store fwd_in, fwd_buffer, fwd_out - # NOTE: It might seems a little bit weird here, we just want to align it with the older version - # of MetaInfoProp. In the future we might modify this part to make it clearer. - fwd_in = [] - fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] - - return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out - - -@meta_register.register(torch.nn.Softmax) -@meta_register.register(torch.nn.functional.softmax) -def softmax_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: - """torch.nn.Softmax metainfo generator - Returns: - Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs - """ - input_tensor = next( - filter( - lambda x: - (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', - args)).data - output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - softmax_dim = next(filter(lambda x: x.name == 'softmax_dim', args)).data - - # calculate cost - - # calculate compute cost - fwd_compute_cost = flop_mapping[torch.ops.aten._softmax.default]([input_tensor], [output_tensor]) - bwd_compute_cost = flop_mapping[torch.ops.aten._softmax_backward_data.default]([output_tensor], [input_tensor]) - - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) - - # calculate memory cost - # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), - parameter=0, - temp=0, - buffer=0) - bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), - parameter=0, - temp=activation_size(input_tensor), - buffer=0) - - # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) - - memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - - # store fwd_in, fwd_buffer, fwd_out - fwd_in = [] - fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] - - return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + input_tensor = next( + filter( + lambda x: + (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', + args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = 1 if kwargs.get('inplace', False) else 0 + + flop_counter = elementwise_flop_counter(1, 0) + # calculate compute cost + fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) + bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + # NOTE: if in_place is True, we will not create a new tensor in forward + fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), + parameter=0, + temp=0, + buffer=0) + + # temp_mem_scale is for situation like softmax backward + bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), + parameter=0, + temp=activation_size(input_tensor) * temp_mem_scale, + buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + return meta_func + + +# the following elementwise ops doesn't have temp memory during backward +zero_temp_mem_ops = [torch.nn.ReLU, torch.nn.functional.relu, torch.tanh] + +# the following elementwise ops have temp memory the same size as input during backward +one_temp_mem_ops = [torch.nn.Softmax, torch.nn.functional.softmax] + +# register meta information +meta_register.register(zero_temp_mem_ops)(elementwise_meta_info()) +meta_register.register(one_temp_mem_ops)(elementwise_meta_info(1)) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index b9b42f8c161d..b10de379b2ca 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -17,51 +17,14 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -def _ReLU_module_mem_test(rank, world_size, port): - """This function is for ReLU memory test - Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL - - Args: - Args: - rank: device rank - bias: indicate whether conv module need bias - world_size: number of devices - port: port for initializing process group - """ - disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model = nn.Sequential(nn.ReLU()).cuda() - input = torch.rand(4, 128, 64, 64).cuda() - input.requires_grad = True - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - - # index of target node in computation graph - node_index = 1 - # total number of target node strategies - strategy_number = 1 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_ReLU_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) - - @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") -def test_sofmax_meta_info(): - meta_func = meta_register.get(torch.nn.functional.softmax) +@parameterize('func', [ + torch.nn.functional.softmax, + torch.nn.functional.relu, + torch.tanh, +]) +def test_activation_meta_info(func): + meta_func = meta_register.get(func) # construct meta tensors input_tensor = torch.rand(256, 1024, device="meta") output_tensor = torch.rand(256, 1024, device="meta") @@ -87,7 +50,7 @@ def test_sofmax_meta_info(): # fwd torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() - output_real_tensor = torch.nn.functional.softmax(input_real_tensor, dim=softmax_dim) + output_real_tensor = func(input_real_tensor) fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 @@ -104,5 +67,4 @@ def test_sofmax_meta_info(): if __name__ == '__main__': - # test_ReLU_meta_concrete_info_match() - test_sofmax_meta_info() + test_activation_meta_info() From e37062d3321765c7417e7574b6fcd6c0a2fc8d0e Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 22:03:43 +0800 Subject: [PATCH 2/3] [autoparallel] remove redundant code --- .../meta_profiler/meta_registry/activation.py | 112 ------------------ 1 file changed, 112 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 457478144ba2..7d9ecd5cdc39 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -8,120 +8,8 @@ from ..registry import meta_register -# __all__ = ["relu_meta_info"] __all__ = ["elementwise_meta_info"] -# @meta_register.register(torch.nn.ReLU) -# def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: -# """torch.nn.ReLU metainfo generator -# The aten graph of torch.nn.ReLU is -# graph(): -# %input_2 : [#users=1] = placeholder[target=placeholder](default=) -# %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {}) -# %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) -# %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {}) -# %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {}) -# %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {}) -# %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) - -# Returns: -# Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs -# """ - -# input_tensor = args[0].data -# output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data -# is_inplace = kwargs.get("inplace", False) - -# # construct input args for forward -# fwd_in_args = [input_tensor] - -# # construct input args for backward -# bwd_in_args = [output_tensor] - -# # calculate cost -# # the fwd op with compute cost is relu.default -# # the bwd op with compute cost is threshold_backward - -# # calculate compute cost -# fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,)) -# bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,)) -# compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) - -# # calculate memory cost -# # NOTE: the inplace ReLU don't have forward memory cost -# # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward -# fwd_memory_cost = MemoryCost( -# activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]), -# parameter=0, -# temp=0, -# buffer=0) - -# bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0) - -# # total cost is the sum of forward and backward cost -# total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, -# parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) - -# memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - -# # store fwd_in, fwd_buffer, fwd_out -# # NOTE: It might seems a little bit weird here, we just want to align it with the older version -# # of MetaInfoProp. In the future we might modify this part to make it clearer. -# fwd_in = [] -# fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] -# fwd_out = [torch.zeros_like(output_tensor, device='meta')] - -# return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out - -# @meta_register.register(torch.nn.Softmax) -# @meta_register.register(torch.nn.functional.softmax) -# def softmax_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: -# """torch.nn.Softmax metainfo generator -# Returns: -# Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs -# """ -# input_tensor = next( -# filter( -# lambda x: -# (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', -# args)).data -# output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data -# softmax_dim = next(filter(lambda x: x.name == 'softmax_dim', args)).data - -# # calculate cost - -# # calculate compute cost -# fwd_compute_cost = flop_mapping[torch.ops.aten._softmax.default]([input_tensor], [output_tensor]) -# bwd_compute_cost = flop_mapping[torch.ops.aten._softmax_backward_data.default]([output_tensor], [input_tensor]) - -# compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) - -# # calculate memory cost -# # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward -# fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), -# parameter=0, -# temp=0, -# buffer=0) -# bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), -# parameter=0, -# temp=activation_size(input_tensor), -# buffer=0) - -# # total cost is the sum of forward and backward cost -# total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, -# parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, -# temp=fwd_memory_cost.temp + bwd_memory_cost.temp, -# buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) - -# memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - -# # store fwd_in, fwd_buffer, fwd_out -# fwd_in = [] -# fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] -# fwd_out = [torch.zeros_like(output_tensor, device='meta')] - -# return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out - def elementwise_meta_info(temp_mem_scale: float = 0) -> Callable: """This is a function to create the meta information generator for elementwise operations From f0416e719a2332ad87040dd3bceffff5aa4d333e Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 16 Feb 2023 22:26:31 +0800 Subject: [PATCH 3/3] [autoparallel] patch meta information of torch.nn.Dropout --- .../meta_profiler/meta_registry/activation.py | 31 ++++++++++--------- .../test_metainfo/test_activation_metainfo.py | 1 + 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 7d9ecd5cdc39..faeed9f29e61 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -11,11 +11,12 @@ __all__ = ["elementwise_meta_info"] -def elementwise_meta_info(temp_mem_scale: float = 0) -> Callable: +def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0) -> Callable: """This is a function to create the meta information generator for elementwise operations Args: - temp_mem_scale (float, optional): temp memory scaling factor. Defaults to 0. + temp_mem_scale (float, optional): temp memory scaling factor for backward. Defaults to 0. + buffer_mem_scale (float, optional): buffer memory scaling factor for forward. Defaults to 0. Returns: Callable: meta information generator @@ -45,13 +46,15 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), parameter=0, temp=0, - buffer=0) + buffer=activation_size(input_tensor) * buffer_mem_scale) # temp_mem_scale is for situation like softmax backward - bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), - parameter=0, - temp=activation_size(input_tensor) * temp_mem_scale, - buffer=0) + # the buffer will be removed during backward phase + bwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, + parameter=0, + temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, + buffer=0) # total cost is the sum of forward and backward cost total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, @@ -71,12 +74,12 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor return meta_func -# the following elementwise ops doesn't have temp memory during backward -zero_temp_mem_ops = [torch.nn.ReLU, torch.nn.functional.relu, torch.tanh] +# register meta information +# (0, 0) +meta_register.register([torch.nn.ReLU, torch.nn.functional.relu, torch.tanh])(elementwise_meta_info(0, 0)) -# the following elementwise ops have temp memory the same size as input during backward -one_temp_mem_ops = [torch.nn.Softmax, torch.nn.functional.softmax] +# (1, 0) +meta_register.register([torch.nn.Softmax, torch.nn.functional.softmax])(elementwise_meta_info(1, 0)) -# register meta information -meta_register.register(zero_temp_mem_ops)(elementwise_meta_info()) -meta_register.register(one_temp_mem_ops)(elementwise_meta_info(1)) +# (0, 0.25) for dropout, the buffer is in bool type so that the buffer memory cost is 0.25 times of input tensor +meta_register.register([torch.nn.Dropout, torch.nn.functional.dropout])(elementwise_meta_info(0, 0.25)) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index b10de379b2ca..e41ac4fa690b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -22,6 +22,7 @@ torch.nn.functional.softmax, torch.nn.functional.relu, torch.tanh, + torch.nn.functional.dropout, ]) def test_activation_meta_info(func): meta_func = meta_register.get(func)