17 changes: 8 additions & 9 deletions colossalai/nn/_ops/embedding.py
@@ -42,10 +42,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
# embedding_1Drow splits the weight (lookup table) into shards of shape [num_embeddings/P, embedding_dim]
# get the indices that fall in the current shard and mask the other shards' indices with 0

# get complete input tensor through all-gather
input_tensor = input_tensor.redistribute(ReplicaSpec())

# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -54,12 +54,11 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition

# Build the mask.
input_mask = (input_tensor < vocab_start_index) | \
(input_tensor >= vocab_end_index)
# Mask the input.
# build the mask.
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
# mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.clone() - vocab_start_index
masked_input = input_tensor - vocab_start_index
masked_input[input_mask] = 0

partial_output = F.embedding(masked_input,
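The masking logic in this hunk is the heart of the row-parallel embedding: each rank owns a contiguous slice of the vocabulary, shifts global token ids into its local range, zeroes the ids it does not own, and relies on a reduction across the tensor-parallel group to assemble the full lookup. A minimal single-process sketch of that arithmetic (plain PyTorch only; the two shards and the final all-reduce are simulated with a loop and a sum rather than a real ColossalAI process group):

```python
import torch
import torch.nn.functional as F

# A vocabulary of 8 entries split row-wise across 2 simulated ranks (4 rows each).
torch.manual_seed(0)
full_weight = torch.randn(8, 3)
num_embeddings_per_partition = 4
x = torch.tensor([0, 2, 5, 7])    # global token ids

partial_outputs = []
for rank in range(2):
    vocab_start_index = rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
    local_weight = full_weight[vocab_start_index:vocab_end_index]

    # Same masking as colo_embedding_1Drow: shift ids into the local range,
    # then point the ids this shard does not own at row 0.
    input_mask = (x < vocab_start_index) | (x >= vocab_end_index)
    masked_input = x - vocab_start_index
    masked_input[input_mask] = 0

    # Look up locally and zero the rows that belong to other shards.
    partial = F.embedding(masked_input, local_weight)
    partial[input_mask, :] = 0.0
    partial_outputs.append(partial)

# In the real op this sum is an all-reduce over the tensor-parallel group.
out = sum(partial_outputs)
assert torch.allclose(out, F.embedding(x, full_weight))
```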
16 changes: 16 additions & 0 deletions tests/test_tensor/_utils/_util.py
@@ -5,6 +5,7 @@
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


def set_seed(seed):
@@ -57,3 +58,18 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
else:
raise NotImplementedError


def split_param_single_dim_tp1d(dim, param, pg):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
if param.process_group.tp_world_size() == 1:
param.set_process_group(pg)
param.set_tensor_spec(*spec)


def split_param_row_tp1d(param, pg):
split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param, pg):
split_param_single_dim_tp1d(-1, param, pg)
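
The new helpers only attach a spec to the parameter: split_param_row_tp1d records a 1-D shard along dim 0 and split_param_col_tp1d along the last dim, both with TP1D compute. As a rough illustration of what that spec describes for the underlying data (this mirrors the chunk-based comparison in tensor_shard_equal above; the real helpers record metadata and do not copy tensors like this):

```python
import torch

def shard_1d(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Slice that a TP1D ShardSpec([dim], [world_size]) assigns to `rank` (illustration only)."""
    return tensor.chunk(world_size, dim=dim)[rank].contiguous()

weight = torch.arange(32.0).reshape(4, 8)

row_shard = shard_1d(weight, dim=0, rank=1, world_size=2)    # what split_param_row_tp1d describes
col_shard = shard_1d(weight, dim=-1, rank=1, world_size=2)   # what split_param_col_tp1d describes

print(row_shard.shape)  # torch.Size([2, 8])
print(col_shard.shape)  # torch.Size([4, 4])
```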
29 changes: 9 additions & 20 deletions tests/test_tensor/test_addmm_tp.py
@@ -4,12 +4,11 @@
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import ShardSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.tensor import ColoTensorSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from _utils import tensor_shard_equal, tensor_equal
from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d


class Conv1D(nn.Module):
@@ -36,28 +35,18 @@ def forward(self, x):
return x


def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)


def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)


def run_with_spec(spec_init_func):
def run_with_spec(spec_init_func, split_bias):
model = Conv1D(4, 16).cuda()
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)

weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))

spec_init_func(weight, bias, pg)
spec_init_func(weight, pg)
if split_bias:
spec_init_func(bias, pg)

x = torch.rand(2, 16).cuda()
out = model(x)
colo_out = torch.addmm(bias, x, weight)
@@ -72,8 +61,8 @@ def run_with_spec(spec_init_func):

def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_row)
run_with_spec(init_1d_col)
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)


@pytest.mark.dist
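The new split_bias flag encodes how addmm composes with a 1-D split of the Conv1D weight, which is stored as [in_features, out_features]: a row split (dim 0) partitions the input features, so every rank produces a partial sum over the full output and the bias stays replicated, while a column split (last dim) partitions the output features, so the bias must be split the same way. A small single-process sketch of the two cases, assuming that weight layout:

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 16)
weight = torch.rand(16, 4)    # Conv1D keeps its weight as [in_features, out_features]
bias = torch.rand(4)
ref = torch.addmm(bias, x, weight)

# Row split (dim 0): each rank holds half of the input features; the partial
# products are summed (an all-reduce in the real op) and the bias is added once.
w_row0, w_row1 = weight.chunk(2, dim=0)
x0, x1 = x.chunk(2, dim=-1)
row_out = x0 @ w_row0 + x1 @ w_row1 + bias
assert torch.allclose(row_out, ref)

# Column split (last dim): each rank holds half of the output features, so the
# bias is split the same way; outputs are concatenated (an all-gather in the real op).
w_col0, w_col1 = weight.chunk(2, dim=-1)
b0, b1 = bias.chunk(2, dim=-1)
col_out = torch.cat([x @ w_col0 + b0, x @ w_col1 + b1], dim=-1)
assert torch.allclose(col_out, ref)
```

Note that test_linear_tp.py below pairs the splits the other way around (column split without a bias split, row split with one) because nn.Linear stores its weight transposed, as [out_features, in_features].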
16 changes: 5 additions & 11 deletions tests/test_tensor/test_embedding_bag_tp.py
@@ -1,5 +1,3 @@
import torch
from colossalai.tensor import ShardSpec, ColoParameter
from torch.nn import functional as F
from functools import partial

@@ -9,21 +7,17 @@
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_col(weight, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d


def run_with_spec(spec_init_func):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.EmbeddingBag(10, 4).cuda()
weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))

spec_init_func(weight, pg)

inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
offsets = torch.tensor([0, 4]).cuda()
out = model(inputs, offsets=offsets)
@@ -38,7 +32,7 @@ def run_with_spec(spec_init_func):
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_col)
run_with_spec(split_param_col_tp1d)


@pytest.mark.dist
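For EmbeddingBag only the column split is exercised, which works because the pooling (mean by default) is applied independently per embedding dimension, so each shard can pool its slice of the columns and the results are concatenated. A minimal plain-PyTorch sketch using the same inputs and offsets as the test (the shards and the final all-gather are simulated with chunk and cat):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
weight = torch.randn(10, 4)
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
ref = F.embedding_bag(inputs, weight, offsets)    # mean pooling by default

# Column split: pooling acts per embedding dimension, so each shard pools its
# slice of the columns independently and the results are concatenated.
w0, w1 = weight.chunk(2, dim=-1)
out = torch.cat([F.embedding_bag(inputs, w0, offsets),
                 F.embedding_bag(inputs, w1, offsets)], dim=-1)
assert torch.allclose(out, ref)
```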
24 changes: 6 additions & 18 deletions tests/test_tensor/test_embedding_tp.py
@@ -1,5 +1,3 @@
import torch
from colossalai.tensor import ColoTensor, ShardSpec
from torch.nn import functional as F
from functools import partial

@@ -9,26 +7,16 @@
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)


def init_1d_col(weight, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


def run_with_spec(spec_init_func, pg: ProcessGroup):
model = torch.nn.Embedding(12, 32).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))

spec_init_func(weight, pg)

x = torch.tensor((0, 3, 6, 9)).cuda()
out = model(x)
colo_out = F.embedding(x, weight)
@@ -44,8 +32,8 @@ def run_dist(rank, world_size, port):
# config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
run_with_spec(init_1d_row, pg)
run_with_spec(init_1d_col, pg)
run_with_spec(split_param_row_tp1d, pg)
run_with_spec(split_param_col_tp1d, pg)


@pytest.mark.dist
32 changes: 10 additions & 22 deletions tests/test_tensor/test_linear_tp.py
@@ -1,6 +1,3 @@
import torch
from colossalai.tensor import ColoTensor, ShardSpec

from functools import partial

import colossalai
@@ -10,29 +7,20 @@
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)


def run_with_spec(spec_init_func):
def run_with_spec(spec_init_func, split_bias):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.Linear(4, 8).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
spec_init_func(weight, bias, pg)

spec_init_func(weight, pg)
if split_bias:
spec_init_func(bias, pg)

x = torch.rand(2, 4).cuda()
out = model(x)
colo_out = F.linear(x, weight, bias)
@@ -48,8 +36,8 @@ def run_with_spec(spec_init_func):
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_row)
run_with_spec(init_1d_col)
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)


@pytest.mark.dist