5 changes: 1 addition & 4 deletions colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -446,10 +446,7 @@ def meta_index_Tensor(self, indices):
 @register_meta(aten.embedding_dense_backward.default)
 def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                   scale_grad_by_freq):
-    return new((num_weights, grad_output.size(-1)),
-               dtype=grad_output.dtype,
-               device=grad_output.device,
-               layout=grad_output.layout)
+    return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
 
 # ============================== Dropout ===========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
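Note: the fix drops the `device=grad_output.device` keyword because meta kernels only describe shapes and dtypes; the surrounding `new(...)` helper already places results on the meta device. A minimal shape check of what this kernel should produce, with `torch.empty` on the meta device standing in for that helper (an assumption, not the module's code):

import torch

# grad w.r.t. the embedding weight is (num_weights, embedding_dim),
# regardless of the shape of `indices`.
num_weights, embedding_dim = 10, 4
grad_output = torch.empty(2, 3, embedding_dim, device='meta')

grad_weight = torch.empty((num_weights, grad_output.size(-1)),
                          dtype=grad_output.dtype,
                          layout=grad_output.layout,
                          device='meta')
assert grad_weight.shape == (num_weights, embedding_dim)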
58 changes: 53 additions & 5 deletions colossalai/_analyzer/fx/passes/shape_prop.py
@@ -51,7 +51,10 @@ def _normalize_tuple(x):
 
 
 def _current_device(module):
-    return next(module.parameters()).device
+    try:
+        return next(module.parameters()).device
+    except StopIteration:
+        return torch.device('cpu')
 
 
 @compatibility(is_backward_compatible=False)
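`next(module.parameters())` raises `StopIteration` on modules that own no parameters, so the new `try/except` is what keeps shape propagation alive for such submodules, with CPU as the fallback device. A minimal reproduction:

import torch
import torch.nn as nn

# nn.ReLU owns no parameters, so the old one-liner would raise StopIteration.
try:
    device = next(nn.ReLU().parameters()).device
except StopIteration:
    device = torch.device('cpu')
assert device == torch.device('cpu')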
@@ -120,15 +123,18 @@ def _convert_meta(t: torch.Tensor):
                 return t.to('meta')
 
             if isinstance(elem, MetaTensor):
+                if getattr(self, '_is_param', False):
+                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                 return _convert_meta(elem._tensor)
 
             elif isinstance(elem, torch.Tensor):
+                if isinstance(elem, torch.nn.Parameter):
+                    return torch.nn.Parameter(_convert_meta(elem))
                 return _convert_meta(elem)
 
             else:
                 return elem
 
         # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
         is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)
@@ -149,7 +155,11 @@ def _convert_meta(t: torch.Tensor):
             n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
                             tuple(v for v in kwargs.values() if is_pure_tensor(v))
 
-        n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))    # align with SPMD
+        # align with SPMD
+        if isinstance(r, (tuple, list)):
+            n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
+        else:
+            n._meta_data = unwrap_fn(r)
 
         n_info.global_ctx = self.global_hook.ctx
         n_info.curr_ctx = self.global_hook.ctx.copy()
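The `isinstance` guard matters because `_normalize_tuple` wraps single outputs, and recording a 1-tuple as `_meta_data` for an op that returned a bare tensor no longer matches what the op actually produced. A tiny illustration of the asymmetry, assuming `_normalize_tuple` wraps non-tuples as defined earlier in this file:

def _normalize_tuple(x):
    # assumed behavior: wrap single outputs so callers can always iterate
    return x if isinstance(x, tuple) else (x,)

single = object()
assert _normalize_tuple(single) == (single,)    # would mis-record a bare output
assert _normalize_tuple((1, 2)) == (1, 2)       # tuples pass through unchanged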
@@ -175,10 +185,48 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
         Return
             Any: The value returned by the function invocation
         """
+        convert_to_param = False
+        if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
+            convert_to_param = True
         if target in self._custom_dispatch_func:
-            return self._custom_dispatch_func[target](*args, **kwargs)
-        return super().call_function(target, args, kwargs)
+            res = self._custom_dispatch_func[target](*args, **kwargs)
+        else:
+            res = super().call_function(target, args, kwargs)
+        if convert_to_param:
+            return torch.nn.Parameter(res)
+        else:
+            return res
+
+    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return
+            Any: The value returned by the method invocation
+        """
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args
+
+        target_method = getattr(self_obj.__class__, target)
+
+        convert_to_parameter = False
+        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
+                args[0], torch.nn.parameter.Parameter):
+            convert_to_parameter = True
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        res = getattr(self_obj, target)(*args_tail, **kwargs)
+        if convert_to_parameter:
+            return torch.nn.Parameter(res)
+        else:
+            return res
 
     def propagate(self, *args, device=None):
         """
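Both overrides exist because tensor ops on a `torch.nn.Parameter` return a plain `Tensor`, which would make parameter-ness invisible to later passes; the interpreter therefore re-wraps the result. A quick demonstration:

import torch

p = torch.nn.Parameter(torch.randn(2, 4))
v = torch.transpose(p, 0, 1)
assert not isinstance(v, torch.nn.Parameter)    # parameter-ness is lost

rewrapped = torch.nn.Parameter(v.detach())      # what the interpreter restores
assert isinstance(rewrapped, torch.nn.Parameter)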
114 changes: 36 additions & 78 deletions colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None):
 
 
 @register_tracer_impl(F.conv1d, name='_bias_addition_impl')
-def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
+def conv1d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv1d(input, weight, **kwargs)
     else:
-        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
 
 
 @register_tracer_impl(F.conv2d, name='_bias_addition_impl')
-def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
+def conv2d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv2d(input, weight, **kwargs)
     else:
-        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
 
 
 @register_tracer_impl(F.conv3d, name='_bias_addition_impl')
-def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
+def conv3d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv3d(input, weight, **kwargs)
     else:
-        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_single(1),
-                          padding=_single(0),
-                          output_padding=_single(0),
-                          groups=1,
-                          dilation=_single(1)):
+def conv_transpose1d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose1d(input, weight, **kwargs)
     else:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
 
 
 @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_pair(1),
-                          padding=_pair(0),
-                          output_padding=_pair(0),
-                          groups=1,
-                          dilation=_pair(1)):
+def conv_transpose2d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose2d(input, weight, **kwargs)
     else:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_triple(1),
-                          padding=_triple(0),
-                          output_padding=_triple(0),
-                          groups=1,
-                          dilation=_triple(1)):
+def conv_transpose3d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
     if bias is None:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose3d(input, weight, **kwargs)
     else:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
 
 
 @register_tracer_impl(torch.addmm, name='_bias_addition_impl')
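All of these impls follow one pattern: trace the op without its bias, then add the bias back as an explicit broadcast node so the two pieces can be sharded independently. The rewrite is numerically equivalent, as a conv2d spot check shows:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

# fused form vs. the split form produced by the tracer impls above
out_fused = F.conv2d(x, w, bias=b)
out_split = F.conv2d(x, w) + b.reshape((-1, 1, 1))
assert torch.allclose(out_fused, out_split, atol=1e-6)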
24 changes: 19 additions & 5 deletions colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -70,14 +70,28 @@ def target(self, target: Callable) -> None:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()
 
-    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
         """
         Compute sharded opdata based on the given data and sharding spec.
         """
-        return OperationData(name=operation_data.name,
-                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
-                             type=operation_data.type,
-                             logical_shape=operation_data.logical_shape)
+        if isinstance(sharding_spec, ShardingSpec):
+            op_data = OperationData(name=operation_data.name,
+                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+                                    type=operation_data.type,
+                                    logical_shape=operation_data.logical_shape)
+        elif isinstance(sharding_spec, (list, tuple)):
+            data = operation_data.data
+            assert isinstance(data, (list, tuple)), f"Data should be a list or tuple, but got {type(data)}."
+            assert len(data) == len(sharding_spec), "Length of data and sharding spec should be the same."
+            sharded_data = []
+            for d, s in zip(data, sharding_spec):
+                sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
+            op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
+        else:
+            raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
+
+        return op_data
 
     def compute_metainfo(self):
         """
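The new list branch mirrors the single-spec branch element-wise; in both cases the sharded stand-in is a zero tensor on the meta device, which records shape without allocating memory. A sketch of that core trick, with the shard shape hard-coded in place of `get_sharded_shape_per_device()` (a hypothetical value, not the real spec):

import torch

# hypothetical shard of a (1024, 1024) weight on a 2x2 device mesh
sharded_shape = (512, 512)
shard_stub = torch.zeros(sharded_shape, device='meta')
assert shard_stub.is_meta and shard_stub.shape == (512, 512)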
11 changes: 6 additions & 5 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
     # This stream is created for overlapping the communication and computation.
     reduction_stream = torch.cuda.Stream()
 
-    def _add_hook_for_grad_communication(node, param):
+    def _add_hook_for_grad_communication(node, param, name=None):
 
         comm_actions = node.best_strategy.communication_actions
 
-        def _filter_param_to_hook(node, op_data, comm_action):
-            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
+        def _filter_param_to_hook(node, op_data, comm_action, name):
+            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
                 return True
             if node.op == 'get_attr' and isinstance(
                     node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
@@ -402,7 +403,7 @@ def _filter_param_to_hook(node, op_data, comm_action):
         for operation_data, comm_action in comm_actions.items():
             comm_spec_to_use = comm_action.comm_spec
             # register hook to the parameters
-            if _filter_param_to_hook(node, operation_data, comm_action):
+            if _filter_param_to_hook(node, operation_data, comm_action, name=name):
 
                 def wrapper(param, comm_spec, stream, overlap):
 
@@ -442,7 +443,7 @@ def _shard_param(param, target_sharding_spec):
                 param = _shard_param(param, target_sharding_spec)
 
                 setattr(target_module, name, param)
-                _add_hook_for_grad_communication(node, param)
+                _add_hook_for_grad_communication(node, param, name)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
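Passing `name` explicitly matters because `_shard_param` returns a fresh parameter object, so matching the hook target by the original `param.name` can silently fail; the module attribute name is the stable key. The mechanism itself is the standard tensor grad hook, sketched here in single-process form (the real pass launches a communication action on `reduction_stream` instead):

import torch

param = torch.nn.Parameter(torch.randn(4, 4))

def grad_comm_hook(grad: torch.Tensor) -> torch.Tensor:
    # stand-in for the communication action fired by `wrapper` above
    return grad

param.register_hook(grad_comm_hook)
param.sum().backward()
assert param.grad is not None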
@@ -81,7 +81,10 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
+        # addbmm will shrink the first batch dim
+        generator.squeeze_batch_dim = True
+        generators.append(generator)
         return generators
 
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -776,10 +776,6 @@ def validate(self) -> bool:
         bias_op_data = self.op_data['bias']
         assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
 
-        if self.op_data['output'].data.dim() == 2:
-            # addbmm will shrink the first batch dim
-            self.squeeze_batch_dim = True
-
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
                                                                          self.op_data['output'].data.shape)
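These two hunks are one change: the `squeeze_batch_dim` flag moves from a shape-sniffing branch in `validate()` to an explicit opt-in by the handler that constructs the generator. The flag reflects addbmm's semantics, which contract the batch dimension away:

import torch

inp = torch.randn(3, 5)
batch1 = torch.randn(4, 3, 2)
batch2 = torch.randn(4, 2, 5)

# addbmm sums (b, n, m) @ (b, m, p) over b, then adds `inp`: output is (n, p).
out = torch.addbmm(inp, batch1, batch2)
assert out.shape == (3, 5)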
2 changes: 1 addition & 1 deletion colossalai/fx/_meta_regist_12.py
@@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor):
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
     result_type = torch.result_type(self, other)
-    return torch.empty_like(self, dtype=result_type)
+    return torch.empty_like(condition + self + other, dtype=result_type)
 
 
 @register_meta(aten.index.Tensor)
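`aten.where.self` broadcasts all three operands, so inferring the meta output from `self` alone produces a wrong shape whenever `condition` or `other` is larger; `condition + self + other` reproduces the broadcast shape for free:

import torch

condition = torch.tensor([[True], [False]])    # (2, 1)
self_t = torch.tensor([1.0, 2.0, 3.0])         # (3,)
other = torch.zeros(2, 3)                      # (2, 3)

out = torch.where(condition, self_t, other)
assert out.shape == (2, 3)
assert (condition + self_t + other).shape == out.shape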
@@ -1,22 +1,20 @@
-from faulthandler import disable
 from functools import partial
-from xml.dom import WrongDocumentErr
 
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing_extensions import Self
 
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
     ShardingStrategy,
     StrategiesVector,
 )
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
                                      meta_arg_names=meta_arg_names,
                                      node_type='bias_module')
 
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
     #     %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
    #     return add
     graph = tracer.trace(model, meta_args=meta_args_for_tracer)
     gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args_for_tracer.values())
     # [input_1, m1, m2, addmm, output]
     node_list = list(graph.nodes)
     linear_node = node_list[4]
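The test changes track the new analyzer workflow: bias-addition splitting is now opt-in on the tracer, and shape propagation runs as a separate pass over the traced module. A condensed sketch of that sequence, assuming the `colossalai._analyzer` APIs imported above and a hypothetical `nn.Linear` model:

import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

model = nn.Linear(8, 8)
meta_args = {'input': torch.rand(4, 8, device='meta')}

tracer = ColoTracer(bias_addition_split=True)    # split bias adds into their own nodes
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())         # annotate nodes with meta shapes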