Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data.distributed import DistributedSampler

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
Expand Down Expand Up @@ -513,7 +513,7 @@ def seed_worker(worker_id):
**_kwargs)

def get_checkpoint_io(self) -> CheckpointIO:
-self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io

def no_sync(self, model: Module) -> Iterator[None]:
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
-from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile

__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
8 changes: 4 additions & 4 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'


-class HypridParallelCheckpointIO(GeneralCheckpointIO):
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
"""
CheckpointIO for Hybrid Parallel Training.

Expand Down Expand Up @@ -136,7 +136,7 @@ def _optimizer_sharder(optimizer: OptimizerWrapper,

param_id = param_info['param2id'][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)]
-state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
Expand Down Expand Up @@ -189,7 +189,7 @@ def save_sharded_model(self,

# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
-state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0)
Expand Down Expand Up @@ -385,7 +385,7 @@ def save_sharded_optimizer(self,

# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
Expand Down