Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
Expand Down
15 changes: 12 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
Expand All @@ -32,8 +32,17 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
# TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
warnings.warn(
'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)


class LowLevelZeroModel(ModelWrapper):
Expand Down
130 changes: 69 additions & 61 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,95 @@
import tempfile
import os

import pytest
import torch
import torch.distributed as dist
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero import ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
@parameterize('use_safetensors', [False, True])
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()

model_ckpt_dir = tempfile.TemporaryDirectory()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=(get_current_device())):
bert_model = model_builder()
bert_model.config.save_pretrained(save_directory=(model_ckpt_dir.name))

config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()

ckpt_io = GeminiCheckpointIO()
if ckpt_io.coordinator.is_master():
with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, 'pretrained')
bert_model.config.save_pretrained(save_directory=pretrained_path)

# TODO(ver217): use boost api
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()

ckpt_io = GeminiCheckpointIO()
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
ckpt_io.save_model(bert_model, (model_ckpt_dir.name),
ckpt_io.save_model(bert_model, (pretrained_path),
True,
True,
'', (model_size / 3),
use_safetensors=use_safetensors)
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
check_state_dict_equal(bert_model.state_dict(only_rank_0=True, dtype=(torch.float32)),
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
new_bert_model.state_dict(), False)
model_ckpt_dir.cleanup()


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=(get_current_device())):
model = model_builder()
new_model = model_builder()
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)

model.train()
#new model
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
new_chunk_manager = ChunkManager(new_config_dict)
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
new_model = ZeroDDP(new_model, new_gemini_manager)

model_ckpt_dir = tempfile.TemporaryDirectory()
ckpt_io = GeminiCheckpointIO()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
ckpt_io.save_model(model, (model_ckpt_dir.name),
True,
True,
'epoch', (model_size / 3),
use_safetensors=use_safetensors)

if ckpt_io.coordinator.is_master():
ckpt_io.load_model(new_model, (model_ckpt_dir.name), strict=True)
model_dict = model.state_dict(only_rank_0=True)
new_model_dict = new_model.state_dict(only_rank_0=True)
check_state_dict_equal(model_dict, new_model_dict, False)
model_ckpt_dir.cleanup()
@parameterize('shard', [True, False])
@parameterize('model_name', ['transformers_gpt'])
def exam_state_dict(placement_policy, shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(placement_policy=placement_policy)
booster = Booster(plugin=plugin)

model = model_fn()
new_model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])

booster.backward(loss, optimizer)
optimizer.step()

with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
dist.barrier()

booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


def run_dist(rank, world_size, port):
Expand All @@ -92,7 +100,7 @@ def run_dist(rank, world_size, port):


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
35 changes: 21 additions & 14 deletions tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import tempfile

import pytest
import torch
import torch.distributed as dist
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroCheckpointIO
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
Expand All @@ -20,7 +18,8 @@

@clear_cache_before_run()
@parameterize('stage', [2])
def check_low_level_zero_checkpointIO(stage: int):
@parameterize('shard', [True, False])
def check_low_level_zero_checkpointIO(stage: int, shard: bool):
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
booster = Booster(plugin=plugin)
model = resnet18()
Expand All @@ -34,17 +33,25 @@ def check_low_level_zero_checkpointIO(stage: int):
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
dist.barrier()

optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
ckpt_io = LowLevelZeroCheckpointIO()
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
new_model = resnet18()
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)

new_model = resnet18()
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
_, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
if ckpt_io.coordinator.is_master():
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


def run_dist(rank, world_size, port):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
Expand Down Expand Up @@ -35,11 +34,7 @@ def check_torch_ddp_checkpointIO(shard: bool):
optimizer.step()
scheduler.step()

with tempfile.TemporaryDirectory() as tempdir:
obj = [tempdir]
dist.broadcast_object_list(obj, src=0)
tempdir = obj[0] # use the same directory on all ranks

with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
Expand All @@ -66,8 +61,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)

dist.barrier()


def run_dist(rank, world_size, port):
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
Expand Down
21 changes: 21 additions & 0 deletions tests/test_checkpoint_io/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
"""
A temporary directory that is shared across all processes.
"""
ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
with ctx_fn() as tempdir:
try:
obj = [tempdir]
dist.broadcast_object_list(obj, src=0)
tempdir = obj[0] # use the same directory on all ranks
yield tempdir
finally:
dist.barrier()