34 changes: 19 additions & 15 deletions tests/test_tensor/test_gpt2.py
@@ -12,7 +12,7 @@
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -21,18 +21,20 @@

 def init_1d_row_spec(model, pg: ProcessGroup):
     tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n and 'ln' not in n:
-                p.set_tensor_spec(*tensor_spec)
+    for n, p in model.named_parameters():
+        p.set_process_group(pg)
+        if 'weight' in n and 'ln' not in n:
+            p.set_tensor_spec(*tensor_spec)


 def init_1d_col_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_tensor_spec(*spec)
+    for n, p in model.named_parameters():
+        p.set_process_group(pg)
+        if 'ln' not in n and ('weight' in n or 'bias' in n):
+            p.set_tensor_spec(*spec)


 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -48,6 +50,7 @@ def check_grad_equal(model, torch_model, pg: ProcessGroup):


 def run_gpt(init_spec_func, use_ddp):
+    set_seed(13234)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
@@ -67,14 +70,16 @@ def run_gpt(init_spec_func, use_ddp):
         model = ColoDDP(model, process_group=pg)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)

     init_spec_func(model, pg)
     check_param_equal(model, torch_model, pg)
     model.train()
     torch_model.train()
     set_seed(pg.tp_local_rank())
     torch.distributed.barrier()

     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
-        logits = model(input_ids, attn_mask)
+        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
Review thread on the ColoTensor conversion above:

Contributor: Do we really need to convert the input to a ColoTensor? What happens if we pass a plain torch tensor into the colo model?

Contributor Author: I'm not sure whether this line is necessary. Since we have a problem with the transformation between [dp: 4, tp: 1] and [dp: 1, tp: 4], I think we'd better add this line to ensure that this transformation is not triggered. Otherwise, to_replicate may use the [dp: 4, tp: 1] process group, and the tensor wouldn't be replicated when we want to get 4 copies of it.

Contributor: If the dp degree is consistent among the model parameters, I believe we should pass the torch tensor into the model.

Contributor: Did you test the correctness in the original way?
logits = model(input_ids, attn_mask)

Contributor Author (@1SAA, Jul 14, 2022):
> If the dp degree is consistent among the model parameters, I believe we should pass the torch tensor into the model.

I don't think so. Actually, the passed data carries its own distribution information. In a TP environment, the passed data is replicated across the TP process group; in a DP environment, each rank's data is just a part of the full batch. It is not appropriate to treat such data as no different from an ordinary tensor.

Contributor: OK, that makes sense.

+        logits = model(colo_input, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
         loss = criterion(logits, input_ids)
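
To make the thread's point concrete, here is a minimal sketch of the wrapping step, not part of the PR itself. ColoTensor.from_torch_tensor and ColoTensorSpec are used exactly as in the diff above; the tp_degree keyword, shapes, and vocab size are illustrative assumptions.

# Hedged sketch: wrap a plain tensor so it carries its own distribution
# information, as argued in the review thread above.
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup(tp_degree=4)                   # assumed kwarg; a [dp: 1, tp: 4] layout
input_ids = torch.randint(0, 50257, (4, 128))    # plain tensor: carries no layout info

# Pinning the spec to pg means later layout conversions (e.g. to_replicate)
# resolve against [dp: 1, tp: 4] rather than falling back to a default
# [dp: 4, tp: 1] group, which would skip the replication across the TP
# group that the test relies on.
colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))

# logits = model(colo_input, attn_mask)          # as in run_gpt above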
@@ -95,19 +100,18 @@ def run_dist(rank, world_size, port, use_ddp):
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    # run_gpt(init_1d_row_spec, use_ddp)
+    run_gpt(init_1d_row_spec, use_ddp)
     run_gpt(init_1d_col_spec, use_ddp)


 @pytest.mark.dist
-@pytest.mark.skip("under development")
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False, True])
+@pytest.mark.parametrize('use_ddp', [False])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size, use_ddp):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_gpt(4, True)
+    test_gpt(4, False)