3 changes: 2 additions & 1 deletion colossalai/nn/parallel/layers/module_utils.py
@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ComputeSpec
+        # set its process_group, dist_spec and compute_spec
         colo_module = get_colo_module(module)
         colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
+                param.set_process_group(pg)
                 param.set_dist_spec(dist_spec)
                 param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
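The added `param.set_process_group(pg)` call means every `ColoParameter` now carries the `ProcessGroup` handed to `init_colo_module`, alongside its dist spec and compute spec. A minimal sketch of the call site, assuming an initialized distributed context; the import paths and the `mode='row'` value are assumptions drawn from this era of the repo and the tests below, not from this diff:

```python
import torch
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the model so its parameters are ColoParameters.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(4, 8)

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())

# Each ColoParameter in `model` now receives `pg` via param.set_process_group(pg)
# in addition to its dist spec and compute spec.
init_colo_module(model, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=True, mode='row')
```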
24 changes: 13 additions & 11 deletions colossalai/tensor/colo_tensor.py
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
-        Tensor.data.__get__,  # make .data return a torch.Tensor rather than a ColoTensor
+        Tensor.data.__get__,    # make .data return a torch.Tensor rather than a ColoTensor
     }
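Only the comment's alignment changes here, but the line it annotates is easy to miss: because `Tensor.data.__get__` is in the no-wrap set, reading `.data` on a `ColoTensor` hands back a plain `torch.Tensor`. A sketch of that behavior, reusing `torch` and `pg` from the snippet above:

```python
from colossalai.tensor import ColoTensor, ColoTensorSpec

t = ColoTensor.from_torch_tensor(torch.ones(4), ColoTensorSpec(pg))
assert isinstance(t, ColoTensor)
# .data is exempt from wrapping, so it comes back as a regular torch.Tensor
assert not isinstance(t.data, ColoTensor)
```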


@@ -121,11 +121,13 @@ def set_process_group(self, pg: ProcessGroup):
             RuntimeError:
         """
         assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        if self.process_group.tp_world_size() != 1:
-            raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
-
-        if self.dist_spec.placement.value != 'r':
-            raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
+        # if the new pg is the same as the old pg, just return
+        if self.process_group == pg:
+            return
+        assert self.process_group.tp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.dist_spec.placement.value == 'r', \
+            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
 
         self.process_group = pg
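In short: reassigning the identical group now short-circuits, and the old RuntimeErrors survive as assertions with unchanged preconditions. A sketch of the resulting semantics, assuming `t` is a replicated `ColoTensor` whose current group has `tp_world_size() == 1` and `new_pg` is some other `ProcessGroup`:

```python
t.set_process_group(t.process_group)  # same group: early return, no checks run

# Succeeds only while t.process_group.tp_world_size() == 1 and the dist spec
# placement is still 'r' (REPLICATE); otherwise an AssertionError is raised.
t.set_process_group(new_pg)
```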

@@ -290,17 +292,17 @@ def size_global(self, args: Optional[int] = None):

     def is_replicate(self):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
-            or (len(self.dist_spec.num_partitions) == 1
-                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.process_group.tp_world_size() == 1)
+               or (len(self.dist_spec.num_partitions) == 1
+                   and self.dist_spec.num_partitions[0] == 1) \
+               or (self.process_group.tp_world_size() == 1)
 
     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
+               and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+               and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
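These predicates only inspect the spec, so they can be read off a freshly constructed tensor. A sketch, assuming `pg` has `tp_degree > 1`; the `dist_attr` keyword and the `ShardSpec(dims, num_partitions)` argument order are assumptions based on this era of the API, not shown in this diff:

```python
from colossalai.tensor import ColoTensor, ColoTensorSpec, ShardSpec

row = ColoTensor.from_torch_tensor(
    torch.rand(8, 4), ColoTensorSpec(pg, dist_attr=ShardSpec([0], [pg.tp_world_size()])))
assert row.is_sharded() and row.is_shard_1drow()  # placement SHARD, dims == [0]

col = ColoTensor.from_torch_tensor(
    torch.rand(8, 4), ColoTensorSpec(pg, dist_attr=ShardSpec([-1], [pg.tp_world_size()])))
assert col.is_shard_1dcol()  # dims == [-1]
```

Note the breadth of `is_replicate()`: by its third clause, any tensor whose group has `tp_world_size() == 1` reports `True`, whatever its placement.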
24 changes: 14 additions & 10 deletions tests/test_tensor/test_module_spec.py
@@ -1,11 +1,11 @@
-from copy import copy
+from copy import deepcopy
 import pytest
 from functools import partial
 
 import torch
 import torch.multiprocessing as mp
 
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
 
-    model_handy = copy(model)
+    model_handy = deepcopy(model)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     compute_spec = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
 
     x = torch.rand(2, 4).cuda()
+    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
+
     out = model(x)
-    colo_out = model_handy(x)
+    colo_out = model_handy(colo_x)
     assert tensor_equal(out, colo_out)
 
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
+
+    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
 
 
 def run_check_shared_param():
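The `copy` → `deepcopy` switch above matters because `copy.copy` of a module shallow-copies its `__dict__`, so the two models would share the very same parameter objects and `init_colo_module` on one would re-spec the other. A standalone illustration in plain PyTorch:

```python
import copy

import torch

m = torch.nn.Linear(4, 8)
assert copy.copy(m).weight is m.weight          # shallow copy shares Parameters
assert copy.deepcopy(m).weight is not m.weight  # deepcopy yields independent ones
```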
@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size):

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
@@ -214,12 +218,12 @@ def test_module_model(world_size):

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_module_check(2)
+    test_module_linear_1d(4)