Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion colossalai/auto_parallel/meta_profiler/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from ..tensor_shard.constants import *

# Modules that expose an `inplace` constructor flag; MetaInfo reads
# `target.inplace` for these when building kwargs.
INPLACE_MODULE = [nn.ReLU]

# Functional ops treated as inplace (kwargs {'inplace': True}) by MetaInfo.
INPLACE_OPS = [torch.flatten]

# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

Expand Down
4 changes: 3 additions & 1 deletion colossalai/auto_parallel/meta_profiler/metainfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register

__all__ = ['MetaInfo']
Expand Down Expand Up @@ -104,6 +104,8 @@ def compute_metainfo(self):
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}

Expand Down
8 changes: 8 additions & 0 deletions colossalai/auto_parallel/passes/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch

# Functional ops for which the forward *output* tensor is the one kept
# around (used by MetaInfoProp's inplace-node check).
OUTPUT_SAVED_OPS = [
    torch.nn.functional.relu,
    torch.nn.functional.softmax,
    torch.flatten,
]

# Module counterparts of the ops above, matched by submodule class.
OUTPUT_SAVED_MOD = [torch.nn.ReLU, torch.nn.Softmax]
75 changes: 35 additions & 40 deletions colossalai/auto_parallel/passes/meta_info_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS


def _normalize_tuple(x):
Expand Down Expand Up @@ -46,7 +46,7 @@ def _is_inplace(self, node: Node):
"""
Check if the node is inplace operation.
"""
if node.op == 'call_method':
if node.op == 'call_module':
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
Expand Down Expand Up @@ -102,56 +102,51 @@ def node_handler(self, node: Node) -> None:
meta_info: MetaInfo

# set data_ptr for input_tensor in MetaInfo class
input_tensor: List[torch.Tensor] = meta_info.fwd_in
buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
output_tensor: List[torch.Tensor] = meta_info.fwd_out
input_tensors: List[torch.Tensor] = meta_info.fwd_in
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
output_tensors: List[torch.Tensor] = meta_info.fwd_out

if len(input_tensor) > 0:
if self._is_inplace(node):
# inplace operation will not create new tensor, and it only has one parent node
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Operations like flatten will change the shape of the tensor, so in the in-place operation I modify the mechanism to set the data_ptr.

# TODO: Verify this observation
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
parent_node = list(node._input_nodes.keys())[0]
parent_tensor = parent_node.meta.get("fwd_out")[0]
parent_tensor: torch.Tensor
for tensor in input_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in buffer_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in output_tensors:
tensor.data_ptr = parent_tensor.data_ptr

else:
for par in node._input_nodes:
if par.meta:
if len(par.meta["fwd_out"]) > 0:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta["fwd_out"]:
tensor: torch.Tensor
target_tensor = next(
(x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
target_tensor.data_ptr = tensor.data_ptr
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr

# set data_ptr for tensor in input_tensor that is not set
for tensor in input_tensor:
for tensor in input_tensors:
if not tensor.data_ptr():
self._set_data_ptr(tensor)

# attach it to graph_info
graph_info.fwd_in = input_tensor

if self._is_inplace(node):
# inplace operation will not create new tensor
# set data_ptr for buffer_tensor and output_tensor of current node
for tensor in input_tensor:
tensor: torch.Tensor
target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_buffer_tensor.data_ptr = tensor.data_ptr
target_output_tensor.data_ptr = tensor.data_ptr
# attach them to graph_info
graph_info.fwd_tmp = buffer_tensor
graph_info.fwd_out = output_tensor

else:
# set data_ptr for buffer_tensor
for tensor in buffer_tensor:
for tensor in buffer_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_tmp = buffer_tensor

# set data_ptr for output_tensor
for tensor in output_tensor:
for tensor in output_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_out = output_tensor

# attach them to graph_info
graph_info.fwd_in = input_tensors
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors

# fetch other memory informations
memory_cost = meta_info.memory_cost
Expand Down