Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions colossalai/legacy/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self):

# logging
self._verbose = False
self._logger = get_dist_logger()
self._logger = None

@property
def config(self):
Expand All @@ -68,6 +68,12 @@ def verbose(self):
def verbose(self, verbose_: bool):
self._verbose = verbose_

@property
def logger(self):
if self._logger is None:
self._logger = get_dist_logger()
return self._logger

def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.

Expand Down Expand Up @@ -527,7 +533,7 @@ def set_device(self, device_ordinal: int = None):

torch.cuda.set_device(device_ordinal)
if self._verbose:
self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")
self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")

def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
Expand Down Expand Up @@ -563,19 +569,19 @@ def set_seed(self, seed: int):
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])

if self._verbose:
self._logger.info(
self.logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}."
)
else:
if self._verbose:
self._logger.info(
self.logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0],
)
self._logger.info(
self.logger.info(
"WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states",
ranks=[0],
)
Expand Down
5 changes: 4 additions & 1 deletion colossalai/legacy/tensor/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get(self, rank_list: List[int], backend: str = "nccl"):
return self.dict[processgroup_key]


PYTORCHPGDICT_ = PyTorchProcessGroupDict()
PYTORCHPGDICT_ = None


class ProcessGroup:
Expand Down Expand Up @@ -59,6 +59,9 @@ def __init__(
if not torch.distributed.is_initialized():
self.is_init = False
return
global PYTORCHPGDICT_
if PYTORCHPGDICT_ is None:
PYTORCHPGDICT_ = PyTorchProcessGroupDict()

assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"

Expand Down
33 changes: 11 additions & 22 deletions colossalai/shardformer/modeling/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,35 +100,24 @@ def pp_forward(
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
hidden_states = embedding_output
else:
assert (
hidden_states is not None
), f"Current stage is {stage_manager.stage}, hidden_states should not be None"

# Go through encoder
encoder_outputs = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
if not stage_manager.is_last_stage():
hidden_states = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=embedding_output,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
return {"hidden_states": hidden_states}
else:
encoder_outputs = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
return {"hidden_states": encoder_outputs}

# Go through rest layers
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
Expand Down
23 changes: 7 additions & 16 deletions tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
Expand Down Expand Up @@ -160,7 +161,7 @@ def _criterion(outputs, inputs):
input_shape = data["input_ids"].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,))
data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))

sharded_model.train()
if booster.plugin.stage_manager is not None:
Expand Down Expand Up @@ -207,15 +208,11 @@ def check_output_hidden_state(
else:
sharded_hidden_state = sharded_output.last_hidden_state

assert torch.allclose(
org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol
), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert torch.allclose(
org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol
), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)


def check_weight(
Expand All @@ -242,9 +239,7 @@ def check_weight(
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

assert torch.allclose(
org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol
), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol)


def get_grad_tensors_for_check(
Expand Down Expand Up @@ -310,9 +305,7 @@ def check_grad(
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol)


def unwrap_model(
Expand All @@ -337,6 +330,4 @@ def check_all_grad_tensors(check_tensors):
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert torch.allclose(
org_grad, shard_grad, atol=atol, rtol=rtol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
assert_close(org_grad, shard_grad, atol=atol, rtol=rtol)
13 changes: 2 additions & 11 deletions tests/test_shardformer/test_model/test_shard_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
atol, rtol = 2e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
Expand All @@ -62,7 +62,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
atol, rtol = 2e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3

Expand Down Expand Up @@ -154,15 +154,6 @@ def run_vit_test(test_config):
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
],
)
def run_vit_3d_test(test_config):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_zero/test_gemini/test_optim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

Expand Down Expand Up @@ -161,6 +162,9 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
rtol, atol = 1.5e-6, 2e-5
if mixed_precision is torch.bfloat16:
rtol, atol = 2e-3, 2e-3
elif Version(torch.__version__) >= Version("2.0.0"):
rtol, atol = 4e-5, 3e-5

for i, (input_ids, label) in enumerate(train_dataloader):
if i > 2:
break
Expand Down