6 changes: 3 additions & 3 deletions colossalai/nn/_ops/addmm.py
@@ -11,16 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res

-mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
+mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())

# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
-output = reduce_input(partial_output, mat1.get_process_group())
+output = reduce_input(partial_output, mat2.get_process_group())
# input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
-output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
+output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
return output


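A minimal single-process sketch of the `mat1:S[1] x mat2:S[0] = Output:P` arithmetic that `colo_addmm_1Drow` relies on, in plain PyTorch; the shapes, world size, and simulated shard layout are illustrative only, not taken from this PR:

```python
import torch

torch.manual_seed(0)
world_size = 2                      # illustrative TP degree
inp = torch.randn(4, 6)             # the `input_tensor` term of addmm
mat1 = torch.randn(4, 8)            # logically sharded along dim 1 (S[1])
mat2 = torch.randn(8, 6)            # logically sharded along dim 0 (S[0])
beta, alpha = 1.0, 1.0

# Each rank multiplies its column-shard of mat1 with the matching row-shard
# of mat2, producing a partial output (Output:P).
partials = [
    torch.mm(mat1.chunk(world_size, dim=1)[r], mat2.chunk(world_size, dim=0)[r])
    for r in range(world_size)
]
# Summing the partials stands in for reduce_input(...) over the TP group.
out = beta * inp + alpha * sum(partials)

assert torch.allclose(out, torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)
```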
10 changes: 5 additions & 5 deletions colossalai/nn/_ops/linear.py
@@ -3,15 +3,15 @@
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
+def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
pg = weight.get_process_group()
-input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
+input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)

# Output:P
partial_output = F.linear(input_tensor, weight)
@@ -27,7 +27,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
+def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
@@ -48,7 +48,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option

def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
assert mode in ('row', 'col')
-funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
+funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
return funcs[mode](input_tensor, weight, bias)


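The same partial-sum identity underlies `colo_linear_1drow`; a minimal plain-PyTorch sketch, assuming for illustration that the weight is sharded along its input-feature dimension and the bias is added once after the reduction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)    # input, logically sharded along dim 1 (S[1])
w = torch.randn(6, 8)    # weight of shape (out_features, in_features)
b = torch.randn(6)

# Each rank computes F.linear on its shard of the input features; the partial
# outputs sum to the full result, which is what reduce_input(...) provides.
partials = [
    F.linear(x.chunk(world_size, dim=1)[r], w.chunk(world_size, dim=1)[r])
    for r in range(world_size)
]
out = sum(partials) + b   # bias is applied once, after the all-reduce

assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)
```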
12 changes: 7 additions & 5 deletions colossalai/tensor/colo_tensor.py
@@ -204,12 +204,14 @@ def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None)
ColoTensor: a redistributed colotensor
"""
if pg is not None and pg != self.get_process_group():
-print('here _redistribute')
# if the pg is not equal, convert the current tensor to replicated
-self._redistribute(ReplicaSpec())
-self.process_group = pg
-ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
-return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
+handled = self.redistribute(ReplicaSpec())
+else:
+handled = self
+pg = self.process_group

+ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
+return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))

def to_replicate_(self):
"""to_replicate_
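A hedged usage sketch of the reworked `redistribute`: when the target process group differs from the tensor's current one, the tensor is first gathered back to a replica and then re-sharded on the new group. The `ProcessGroup(tp_degree=...)` constructor and the prior distributed launch are assumptions for illustration, not part of this diff:

```python
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec

# Assumes colossalai.launch(...) has already initialized the distributed
# environment; the ProcessGroup construction shown here is hypothetical.
pg = ProcessGroup(tp_degree=2)
t = ColoTensor.from_torch_tensor(torch.randn(4, 8), ColoTensorSpec(pg))

# Re-shard along the last dim, passing the target process group explicitly,
# as the updated call sites in _ops/addmm.py and _ops/linear.py now do.
t = t.redistribute(ShardSpec([-1], [pg.tp_world_size()]), pg)
```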
54 changes: 11 additions & 43 deletions tests/test_tensor/test_model.py
@@ -11,42 +11,13 @@
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
-ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup, ReplicaSpec
+from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer

from tests.components_to_test.registry import non_distributed_component_funcs
+from _utils import split_param_row_tp1d, split_param_col_tp1d


-def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-with DistSpecManager.no_grad():
-weight.set_process_group(pg)
-weight.set_tensor_spec(*spec)


-def init_1d_col_linear(weight, pg):
-spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-with DistSpecManager.no_grad():
-weight.set_process_group(pg)
-weight.set_tensor_spec(*spec)


-def init_1d_row_embedding(weight, pg):
-spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-with DistSpecManager.no_grad():
-weight.set_process_group(pg)
-weight.set_tensor_spec(*spec)


-def init_1d_col_embedding(weight, pg):
-spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-with DistSpecManager.no_grad():
-weight.set_process_group(pg)
-weight.set_tensor_spec(*spec)


def run_1d_hybrid_tp(model_name):
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -79,33 +50,30 @@ def run_1d_hybrid_tp(model_name):

# num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name:
-init_1d_row_linear(p, pg)
+split_param_col_tp1d(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
elif 'word_embeddings' in name and 'weight' in name:
-init_1d_row_embedding(p, pg)
+split_param_row_tp1d(p, pg)
# num_class = seq_len = 512 | (512, 8)
elif 'position_embeddings' in name and 'weight' in name:
-init_1d_row_embedding(p, pg)
+split_param_row_tp1d(p, pg)
# num_class = type_vocab_size = 2 | (2, 8)
elif 'token_type_embeddings' in name and 'weight' in name:
-init_1d_col_embedding(p, pg)
-elif p.process_group.tp_world_size() == 1:
-with DistSpecManager.no_grad():
-p.redistribute(ReplicaSpec(), pg)
+split_param_col_tp1d(p, pg)

elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'embed' in name and 'weight' in name:
-init_1d_col_embedding(p, pg)
+split_param_col_tp1d(p, pg)
if 'proj1' in name and ('weight' in name or 'bias' in name):
-init_1d_col_linear(p, pg)
+split_param_row_tp1d(p, pg)
if 'proj2' in name and 'weight' in name:
-init_1d_row_linear(p, pg)
+split_param_col_tp1d(p, pg)
if 'classifier' in name and ('weight' in name or 'bias' in name):
-init_1d_col_linear(p, pg)
+split_param_row_tp1d(p, pg)

model = model.cuda()
model.train()
@@ -327,9 +295,9 @@ def _run_pretrain_load():

def run_model_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-for name in ['bert']:
+for name in ['bert', 'simple_net']:
run_1d_row_tp(name)
-for name in ['bert']:
+for name in ['bert', 'simple_net']:
run_1d_hybrid_tp(name)


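For reference, a hedged sketch of what the `split_param_*_tp1d` helpers imported from `_utils` are presumed to do, mirroring the `init_1d_*` helpers this test previously defined inline; the actual implementation lives in `tests/test_tensor/_utils` and may differ:

```python
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec,
                               DistSpecManager, ProcessGroup, ShardSpec)


def split_param_1d(param: ColoTensor, pg: ProcessGroup, dim: int) -> None:
    # Shard `param` along `dim` across the tensor-parallel group and mark it
    # for TP1D compute, following the removed init_1d_* helpers above.
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        param.set_process_group(pg)
        param.set_tensor_spec(*spec)


# split_param_row_tp1d(p, pg) would then correspond to split_param_1d(p, pg, 0),
# and split_param_col_tp1d(p, pg) to split_param_1d(p, pg, -1); this mapping is
# an assumption for illustration.
```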