From 86be7440076de238bacdd4bf401513d44025c388 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 6 Jul 2022 17:34:24 +0800 Subject: [PATCH 1/4] make it faster --- tests/test_utils/test_colo_checkpoint.py | 39 +++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 6e7d4441d760..48742fc18a58 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,21 +1,20 @@ from abc import ABC, abstractmethod -import os, sys, shutil +import os, shutil import torch import torch.nn as nn import pytest import copy -import operator -import colossalai -from colossalai.context.parallel_mode import ParallelMode +from functools import partial + import torch.multiprocessing as mp import torch.distributed as dist + +import colossalai from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor -from colossalai.core import global_context as gpc -from functools import partial +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -46,15 +45,17 @@ def __len__(self): class DummyDataLoader(DummyDataGenerator): - batch_size = 128 - category = 16 - feature_size = 256 + + def __init__(self, batch_size, category, feature_size, length=10): + super().__init__(length) + self.batch_size = batch_size + self.category = category + self.feature_size = feature_size def generate(self): image_dict = {} - image_dict['pixel_values'] = torch.rand( - DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1 - image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), + image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(self.category, (self.batch_size,), dtype=torch.int64, device=get_current_device()) return image_dict @@ -102,11 +103,15 @@ def remove(path): def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg): - train_dataloader = DummyDataLoader(length=16) + batch = 3 + feature = 32 + category = 16 + train_dataloader = DummyDataLoader(batch, category, feature, length=16) with ColoInitContext(device=get_current_device()): - model = MLP(256, 16, 64) - model_reload = MLP(256, 16, 64) - model_ref = MLP(256, 16, 64) + model = MLP(feature, category) + model_reload = MLP(feature, category) + model_ref = MLP(feature, category) + model = model.cuda() model_reload = model_reload.cuda() model_ref = model_ref.cuda() From fa315edcb43fd998c967895bc6078aa2d4dd6018 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 13 Jul 2022 14:14:41 +0800 Subject: [PATCH 2/4] [checkpoint] hotfix bugs in colo checkpoint --- colossalai/tensor/colo_tensor.py | 1 - .../utils/checkpoint/module_checkpoint.py | 7 +- tests/test_utils/test_colo_checkpoint.py | 222 +++++++++--------- 3 files changed, 115 insertions(+), 115 deletions(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 0c92ac0c7796..dd678b1a3f83 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -204,7 +204,6 @@ def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) ColoTensor: a redistributed colotensor """ if pg is not None and pg != self.get_process_group(): - print('here _redistribute') # if the pg is not equal, convert the current tensor to replicated self._redistribute(ReplicaSpec()) self.process_group = pg diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py index 3f61aed2f092..119d719b2c33 100644 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -28,7 +28,8 @@ def save_checkpoint(dire: str, if isinstance(v, ColoTensor): mapping[k] = (v.dist_spec, v.compute_spec) new_dict[k] = v.to_replicate().detach() - + else: + new_dict[k] = v if dist.get_rank() == 0: for k, v in new_dict.items(): if isinstance(v, ColoTensor): @@ -60,7 +61,7 @@ def load_checkpoint(dire, """ mapping = dict() - for k, v in model.named_parameters(): + for k, v in model.state_dict().items(): if isinstance(v, ColoTensor): mapping[k] = (v.dist_spec, v.compute_spec) v.to_replicate_() @@ -70,6 +71,6 @@ def load_checkpoint(dire, # reset tensors to original dist spec. with DistSpecManager.no_grad(): - for k, v in model.named_parameters(): + for k, v in model.state_dict().items(): if isinstance(v, ColoTensor): v.set_tensor_spec(*mapping[k]) diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 0581d7bf02f6..2643bf2170a1 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,91 +1,69 @@ -from abc import ABC, abstractmethod import os, shutil import torch -import torch.nn as nn import pytest from functools import partial import torch.multiprocessing as mp import torch.distributed as dist + from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import MultiplicativeLR +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR import colossalai from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup +from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR - - -class DummyDataGenerator(ABC): - - def __init__(self, length=10): - self.length = length - - @abstractmethod - def generate(self): - pass - - def __iter__(self): - self.step = 0 - return self +from colossalai.nn.optimizer import ColoOptimizer - def __next__(self): - if self.step < self.length: - self.step += 1 - return self.generate() - else: - raise StopIteration - - def __len__(self): - return self.length +from tests.components_to_test.registry import non_distributed_component_funcs -class DummyDataLoader(DummyDataGenerator): +def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) - def __init__(self, batch_size, category, feature_size, length=10): - super().__init__(length) - self.batch_size = batch_size - self.category = category - self.feature_size = feature_size - def generate(self): - image_dict = {} - image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1 - image_dict['label'] = torch.randint(self.category, (self.batch_size,), - dtype=torch.int64, - device=get_current_device()) - return image_dict +def init_1d_col_linear(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) -class MLP(nn.Module): +def init_1d_row_embedding(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) - def __init__(self, in_features, out_features, hidden_features=None): - super().__init__() - if hidden_features is None: - hidden_features = out_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.fc2 = nn.Linear(hidden_features, out_features) - self.activation = nn.ReLU() - def forward(self, x): - x = self.fc1(x) - x = self.activation(x) - x = self.fc2(x) - return x +def init_1d_col_embedding(weight, pg): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - for n, p in model.named_parameters(): - if 'weight' in n: - p.set_process_group(pg) - p.set_tensor_spec(*spec) + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + if 'embed' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + if 'proj1' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) + if 'proj2' in name and 'weight' in name: + init_1d_row_linear(p, pg) + if 'classifier' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) def check_param_equal(model, torch_model): @@ -103,56 +81,76 @@ def remove(path): raise ValueError("file {} is not a file or dir.".format(path)) -def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): - num_epoch = 5 - warmup_epoch = 2 +def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - batch = 3 - feature = 32 - category = 16 + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + # set_seed(1) with ColoInitContext(device=get_current_device()): - model = MLP(feature, category) + model = model_builder(checkpoint=True) + model_reload = model_builder(checkpoint=True) - with ColoInitContext(device=get_current_device()): - model_reload = MLP(feature, category) + if use_mp_reload: + if 'bert' == model_name: + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + # num_class = type_vocab_size = 2 | (8, 2) + if 'classifier' in name and 'weight' in name: + init_1d_row_linear(p, pg) + # num_class = vocab_size = 30524 | (30524, 8) + elif 'word_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = seq_len = 512 | (512, 8) + elif 'position_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = type_vocab_size = 2 | (2, 8) + elif 'token_type_embeddings' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + elif p.process_group.tp_world_size() == 1: + with DistSpecManager.no_grad(): + p.redistribute(ReplicaSpec(), pg) + elif "simple_net" == model_name: + init_spec_func(model, pg) model = model.cuda() + model.train() + model_reload = model_reload.cuda() - if use_ddp: - model = ColoDDP(model, pg) - model_reload = ColoDDP(model_reload, pg) + model_reload.train() - init_spec_func(model, pg) - if use_mp_reload: - init_spec_func(model_reload, pg) - - optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) - optimizer_reload = torch.optim.Adam(model_reload.parameters(), - lr=0.001, - betas=(0.9, 0.999), - eps=1e-08, - weight_decay=0) - - lr_scheduler = None - if test_scheduler == 'colossalai_cosine_warmup': - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch) - lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, - total_steps=num_epoch, - warmup_steps=warmup_epoch) - elif test_scheduler == 'torch_cosine': - lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch) - lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch) - elif test_scheduler == 'torch_lambda': - lr_lambda = lambda epoch: 0.95 - lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda) - lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda) - else: - raise TypeError(f"{test_scheduler} is invalid") + colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) + + for i, (data, label) in enumerate(train_dataloader): + + # Zero grad + colo_optimizer.zero_grad() - save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler) + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + loss.backward() + colo_optimizer.step() + + if i > 2: + break + + if not os.path.isdir('./checkpoint') and rank == 0: + os.mkdir('./checkpoint') + save_checkpoint('./checkpoint', 0, model, None, None) dist.barrier() - load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload) + load_checkpoint('./checkpoint', 0, model_reload, None, None) # Since model is sharded, we merge them before param checking. for p in model.parameters(): @@ -163,26 +161,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): check_param_equal(model, model_reload) + if rank == 0: + remove('./checkpoint') + def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): - if use_ddp and world_size == 1: - return - tp_world_size = world_size // 2 if use_ddp else world_size - config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') pg = ProcessGroup(tp_degree=world_size) - run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg) + for model_name in ['bert', 'simple_net']: + _run_checkpoint(model_name, + init_1d_row_for_linear_weight_spec, + use_ddp, + use_mp_reload, + test_scheduler=test_scheduler, + pg=pg) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.parametrize('use_ddp', [True, False]) +@pytest.mark.parametrize('use_ddp', [False]) @pytest.mark.parametrize('use_mp_reload', [True, False]) -@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) +# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) @rerun_if_address_is_in_use() -def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler): - if not os.path.isdir('./checkpoint'): - os.mkdir('./checkpoint') +def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): run_func = partial(run_dist, world_size=world_size, port=free_port(), @@ -190,8 +191,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler): use_mp_reload=use_mp_reload, test_scheduler=test_scheduler) mp.spawn(run_func, nprocs=world_size) - remove('./checkpoint') if __name__ == '__main__': - test_checkpoint(2, True, False, "torch_cosine") + test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine") From 0a9c4b099bcc412bd6bc6efa47688022f17d04cc Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Thu, 14 Jul 2022 15:34:59 +0800 Subject: [PATCH 3/4] polish code --- tests/test_utils/test_colo_checkpoint.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 2643bf2170a1..4557cfa2805c 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -25,30 +25,26 @@ def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) def init_1d_col_linear(weight, pg): spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) def init_1d_row_embedding(weight, pg): spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) def init_1d_col_embedding(weight, pg): spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): @@ -111,8 +107,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch elif 'token_type_embeddings' in name and 'weight' in name: init_1d_col_embedding(p, pg) elif p.process_group.tp_world_size() == 1: - with DistSpecManager.no_grad(): - p.redistribute(ReplicaSpec(), pg) + p.redistribute(ReplicaSpec(), pg) elif "simple_net" == model_name: init_spec_func(model, pg) From 37ba7469cbaa7dfd1280c7732c2ffb82182cf6c5 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Thu, 14 Jul 2022 15:37:34 +0800 Subject: [PATCH 4/4] [optimizer] polish ColoOptimizer --- colossalai/nn/optimizer/colo_optimizer.py | 16 ++++------------ colossalai/tensor/process_group.py | 9 +++++++-- tests/test_tensor/test_model.py | 6 +++--- tests/test_utils/test_colo_checkpoint.py | 4 ++-- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/colossalai/nn/optimizer/colo_optimizer.py b/colossalai/nn/optimizer/colo_optimizer.py index 52c641594fa1..72ac916823ef 100644 --- a/colossalai/nn/optimizer/colo_optimizer.py +++ b/colossalai/nn/optimizer/colo_optimizer.py @@ -24,12 +24,7 @@ def __init__(self, named_params: Mapping[str, Union[Tensor, ColoTensor]], optimi **optimizer_kwargs: the key-word arguments to initialize the optimizer. """ - tensors: List[Tensor] = [] - for value in named_params.values(): - tensors.append(value) - - self.named_params = named_params - self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self._optim = optimizer_class([p for n, p in named_params], *optimizer_args, **optimizer_kwargs) self.param_groups = self._optim.param_groups self.state = self._optim.state @@ -68,8 +63,7 @@ def state_dict(self) -> Dict[str, Any]: Returned state and param_groups will contain parameter keys instead of parameter indices like torch.optim.Optimizer. """ - # TODO: implement state_dict - raise NotImplementedError("ColoOptimizer state_dict not implemented yet!") + return self._optim.state_dict() def load_state_dict(self, state_dict: Mapping[str, Any]): r"""Loads the ColoOptimizer state. @@ -78,11 +72,9 @@ def load_state_dict(self, state_dict: Mapping[str, Any]): state_dict (dict): ColoOptimizer state. Should be an object returned from a call to :meth:`state_dict`. """ - # TODO: implement load_state_dict - raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!") + self._optim.load_state_dict(state_dict) def add_param_group(self, param_group: Any): r"""Add a new param group """ - # TODO: implement add_param_group - raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!") + self._optim.add_param_group(param_group) diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py index 1624638c4117..f6330c2b1996 100644 --- a/colossalai/tensor/process_group.py +++ b/colossalai/tensor/process_group.py @@ -48,6 +48,7 @@ def __init__(self, tp_degree: Optional[int] = None, dp_degree: Optional[int] = None) -> None: if not torch.distributed.is_initialized(): + self.is_init = False return assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" @@ -96,6 +97,7 @@ def __init__(self, self._has_cpu_groups = False PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + self.is_init = True def set_cpu_groups(self): if self.has_cpu_groups: @@ -110,8 +112,11 @@ def has_cpu_groups(self): return self._has_cpu_groups def __repr__(self): - return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\ - format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list) + if self.is_init: + return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\ + format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list) + else: + return "ProcessGroup not initialized" def __eq__(self, obj: 'ProcessGroup') -> bool: if not isinstance(obj, ProcessGroup): diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 34a376891f85..ee5edae2c578 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -33,7 +33,7 @@ def run_1d_hybrid_tp(model_name): if rank == 0: model_torch = model_builder(checkpoint=True) model_torch = model_torch.cuda() - optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1) + optimizer_torch = ColoOptimizer(model_torch.named_parameters(), torch.optim.SGD, lr=0.1) # Make two models have the same init params for p1, p2 in zip(model.parameters(), model_torch.parameters()): @@ -80,7 +80,7 @@ def run_1d_hybrid_tp(model_name): if rank == 0: model_torch.train() - colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) + colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1) for i, (data, label) in enumerate(train_dataloader): @@ -170,7 +170,7 @@ def test_colo_optimizer(): with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): model = model_builder(checkpoint=True) - colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) + colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1) for i, (data, label) in enumerate(train_dataloader): colo_optimizer.zero_grad() data = data.to(get_current_device()) diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 2643bf2170a1..23cbdeb648e7 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -16,9 +16,9 @@ from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec -from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint from colossalai.nn.optimizer import ColoOptimizer +from colossalai.nn.parallel.data_parallel import ColoDDP from tests.components_to_test.registry import non_distributed_component_funcs @@ -122,7 +122,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch model_reload = model_reload.cuda() model_reload.train() - colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) + colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1) for i, (data, label) in enumerate(train_dataloader):