Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions colossalai/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
from .comparison import (
assert_close,
assert_close_loose,
assert_equal,
assert_equal_in_group,
assert_not_equal,
check_state_dict_equal,
)
from .pytest_wrapper import run_on_environment_flag
from .utils import (
clear_cache_before_run,
Expand All @@ -13,5 +20,5 @@
__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
'clear_cache_before_run', 'run_on_environment_flag'
'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal'
]
24 changes: 24 additions & 0 deletions colossalai/testing/comparison.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import OrderedDict

import torch
import torch.distributed as dist
from torch import Tensor
Expand Down Expand Up @@ -28,3 +30,25 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
a = tensor_list[i]
b = tensor_list[i + 1]
assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'


def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
    """Assert that every entry of ``d1`` has an equal counterpart in ``d2``.

    Only the keys of ``d1`` are walked: nested dicts are compared recursively,
    lists are compared element by element (driven by the length of the list in
    ``d1``), tensors are compared with ``torch.equal`` and everything else with
    ``==``. Keys or trailing list items that exist only in ``d2`` are not
    inspected. Neither input is modified.

    Args:
        d1: Reference state dict whose keys drive the comparison.
        d2: State dict compared against ``d1``.
        ignore_device: When ``False``, both tensors are passed through
            ``.to("cpu")`` before being compared. When ``True`` (default) the
            tensors are compared as they are. The flag is forwarded to nested
            dicts.

    Raises:
        AssertionError: If a pair of values differs.
        KeyError: If a key of ``d1`` is missing from ``d2``.
        IndexError: If a list in ``d2`` is shorter than its counterpart in ``d1``.
    """

    def _tensors_equal(t1, t2):
        if not ignore_device:
            # Work on local CPU handles only; writing them back into the dicts
            # would silently relocate the caller's (e.g. optimizer) state.
            t1, t2 = t1.to("cpu"), t2.to("cpu")
        return torch.equal(t1, t2)

    for k, v1 in d1.items():
        v2 = d2[k]
        if isinstance(v1, dict):
            # Forward the flag so nested state follows the same device policy.
            check_state_dict_equal(v1, v2, ignore_device)
        elif isinstance(v1, list):
            for i, item1 in enumerate(v1):
                item2 = v2[i]
                if isinstance(item1, torch.Tensor):
                    assert _tensors_equal(item1, item2), f"tensors differ at key {k!r}, index {i}"
                else:
                    assert item1 == item2, f"values differ at key {k!r}, index {i}"
        elif isinstance(v1, torch.Tensor):
            assert _tensors_equal(v1, v2), f"tensors differ at key {k!r}"
        else:
            assert v1 == v2, f"values differ at key {k!r}"
98 changes: 98 additions & 0 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import tempfile

import pytest
import torch

import colossalai
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs


# Saves a ZeroDDP-wrapped BERT model with GeminiCheckpointIO, reloads the result
# through BertForSequenceClassification.from_pretrained and asserts that the two
# state dicts match.
# NOTE(review): indentation was lost in this view, so the exact extent of the
# `with` and `if` bodies below cannot be confirmed from here.
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
from transformers import BertForSequenceClassification

# Temporary directory that receives both the model config and the checkpoint.
model_ckpt_dir = tempfile.TemporaryDirectory()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=(get_current_device())):
bert_model = model_builder()
# The config is written next to the checkpoint; from_pretrained below reads the same directory.
bert_model.config.save_pretrained(save_directory=(model_ckpt_dir.name))

# Wrap the model with Gemini: chunk configuration -> ChunkManager -> GeminiManager -> ZeroDDP.
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()

ckpt_io = GeminiCheckpointIO()
if ckpt_io.coordinator.is_master():
# Total parameter size in MiB (bytes / 1024**2).
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
# NOTE(review): the positional arguments (True, True, '', model_size / 3) are opaque
# here; presumably the last one is a size limit of a third of the model — confirm
# against save_model's signature.
ckpt_io.save_model(bert_model, (model_ckpt_dir.name),
True,
True,
'', (model_size / 3),
use_safetensors=use_safetensors)
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
# Passing False as ignore_device makes the helper move tensors to CPU before comparing.
check_state_dict_equal(bert_model.state_dict(only_rank_0=True, dtype=(torch.float32)),
new_bert_model.state_dict(), False)
model_ckpt_dir.cleanup()


# Saves a ZeroDDP-wrapped model with GeminiCheckpointIO, loads the checkpoint into a
# second ZeroDDP-wrapped model built from the same builder and asserts that both
# state dicts match.
# NOTE(review): indentation was lost in this view, so the exact extent of the
# `with` and `if` bodies below cannot be confirmed from here.
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=(get_current_device())):
model = model_builder()
new_model = model_builder()
# Wrap the source model with Gemini: chunk configuration -> ChunkManager -> GeminiManager -> ZeroDDP.
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)

model.train()
# The target model gets its own chunk configuration and managers.
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
new_chunk_manager = ChunkManager(new_config_dict)
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
new_model = ZeroDDP(new_model, new_gemini_manager)

model_ckpt_dir = tempfile.TemporaryDirectory()
ckpt_io = GeminiCheckpointIO()
# Total parameter size in MiB (bytes / 1024**2).
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
# NOTE(review): the positional arguments (True, True, 'epoch', model_size / 3) are opaque
# here; presumably the last one is a size limit of a third of the model — confirm
# against save_model's signature.
ckpt_io.save_model(model, (model_ckpt_dir.name),
True,
True,
'epoch', (model_size / 3),
use_safetensors=use_safetensors)

if ckpt_io.coordinator.is_master():
ckpt_io.load_model(new_model, (model_ckpt_dir.name), strict=True)
model_dict = model.state_dict(only_rank_0=True)
new_model_dict = new_model.state_dict(only_rank_0=True)
# Passing False as ignore_device makes the helper move tensors to CPU before comparing.
check_state_dict_equal(model_dict, new_model_dict, False)
model_ckpt_dir.cleanup()


def run_dist(rank, world_size, port):
    """Initialise colossalai for this process, then run both Gemini checkpoint checks."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    exam_state_dict()
    exam_state_dict_with_origin()


# Pytest entry point: hands run_dist and the parametrized world_size to spawn
# (presumably one process per rank — confirm against spawn's implementation).
# NOTE(review): the parametrize list is [4, 4], so the same world size is collected
# twice; confirm whether a single entry or a different size was intended.
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
133 changes: 10 additions & 123 deletions tests/test_checkpoint_io/test_general_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import tempfile

import pytest
import torch
from torch.optim import Adam
from torchvision.models import resnet18

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize

# ========
# Note:
Expand Down Expand Up @@ -61,10 +54,10 @@ def test_unsharded_checkpoint(use_safetensors: bool):
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)


# check for model and optimizer state dict recursively
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())


@pytest.mark.parametrize('use_safetensors', [True, False])
def test_sharded_checkpoint(use_safetensors: bool):
Expand All @@ -87,7 +80,7 @@ def test_sharded_checkpoint(use_safetensors: bool):
else:
suffix = ".bin"
WEIGHTS_INDEX_NAME = "model.bin.index.json"

model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

Expand All @@ -96,7 +89,7 @@ def test_sharded_checkpoint(use_safetensors: bool):

ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)

# create new model
new_model = resnet18()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
Expand All @@ -105,111 +98,5 @@ def test_sharded_checkpoint(use_safetensors: bool):
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

# check for model and optimizer state dict recursively
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())

# Saves a ZeroDDP-wrapped BERT model with GeminiCheckpointIO, reloads the result
# through BertForSequenceClassification.from_pretrained and checks both state dicts
# with recursive_check.
# NOTE(review): indentation was lost in this view, so the exact extent of the
# `with` and `if` bodies below cannot be confirmed from here.
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
# NOTE(review): only BertForSequenceClassification is used in this function.
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification

# Temporary directory that receives both the model config and the checkpoint.
model_ckpt_dir = tempfile.TemporaryDirectory()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()

with ColoInitContext(device=get_current_device()):
bert_model = model_builder()
# The config is written next to the checkpoint; from_pretrained below reads the same directory.
bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
# Wrap the model with Gemini: chunk configuration -> ChunkManager -> GeminiManager -> ZeroDDP.
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()

ckpt_io = GeminiCheckpointIO()
if ckpt_io.coordinator.is_master():
# Total parameter size in MiB (bytes / 1024**2).
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
# NOTE(review): the positional arguments (True, True, "", model_size / 3) are opaque here;
# presumably the last one is a size limit — confirm against save_model's signature.
ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())

model_ckpt_dir.cleanup()



# Saves a ZeroDDP-wrapped model with GeminiCheckpointIO, loads the checkpoint into a
# second ZeroDDP-wrapped model built from the same builder and checks both state
# dicts with recursive_check.
# NOTE(review): indentation was lost in this view, so the exact extent of the
# `with` and `if` bodies below cannot be confirmed from here.
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()

with ColoInitContext(device=get_current_device()):
model = model_builder()
new_model = model_builder()

# Wrap the source model with Gemini: chunk configuration -> ChunkManager -> GeminiManager -> ZeroDDP.
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
model.train()

# The target model gets its own chunk configuration and managers.
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
new_chunk_manager = ChunkManager(new_config_dict)
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
new_model = ZeroDDP(new_model, new_gemini_manager)

model_ckpt_dir = tempfile.TemporaryDirectory()

ckpt_io = GeminiCheckpointIO()
# Total parameter size in MiB (bytes / 1024**2).
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
# NOTE(review): the positional arguments (True, True, "epoch", model_size / 3) are opaque here;
# presumably the last one is a size limit — confirm against save_model's signature.
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)

# load model
if ckpt_io.coordinator.is_master():
ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
model_dict = model.state_dict(only_rank_0=True)
new_model_dict = new_model.state_dict(only_rank_0=True)
recursive_check(model_dict, new_model_dict)

model_ckpt_dir.cleanup()


def run_dist(rank, world_size, port):
    """Initialise colossalai for this process, then run both Gemini checkpoint checks."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    exam_state_dict()
    hf_load_colossalai_checkpoint()


# Pytest entry point: hands run_dist and the parametrized world_size to spawn
# (presumably one process per rank — confirm against spawn's implementation).
# NOTE(review): the parametrize list is [4, 4], so the same world size is collected
# twice; confirm whether a single entry or a different size was intended.
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)


# do recursive check for the optimizer state dict
# if the value is a dict, compare its values
# if the value is a list, compare all elements one-by-one
# if the value is a torch.Tensor, use torch.equal
# otherwise use assertEqual
def recursive_check(d1, d2):
    """Assert that every entry of ``d1`` has an equal counterpart in ``d2``.

    Only the keys of ``d1`` are walked. Nested dicts recurse, lists are compared
    element by element (driven by the length of the list in ``d1``), tensors are
    passed through ``.to("cpu")`` and compared with ``torch.equal``, and any
    other value is compared with ``==``. Neither input is modified.

    Raises:
        AssertionError: If a pair of values differs.
        KeyError: If a key of ``d1`` is missing from ``d2``.
        IndexError: If a list in ``d2`` is shorter than its counterpart in ``d1``.
    """

    def _tensors_equal(t1, t2):
        # Compare local CPU handles; writing them back into the dicts would
        # silently relocate the caller's (e.g. optimizer) state.
        return torch.equal(t1.to("cpu"), t2.to("cpu"))

    for k, v1 in d1.items():
        v2 = d2[k]
        if isinstance(v1, dict):
            recursive_check(v1, v2)
        elif isinstance(v1, list):
            for i, item1 in enumerate(v1):
                item2 = v2[i]
                if isinstance(item1, torch.Tensor):
                    assert _tensors_equal(item1, item2)
                else:
                    assert item1 == item2
        elif isinstance(v1, torch.Tensor):
            assert _tensors_equal(v1, v2)
        else:
            assert v1 == v2
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
57 changes: 57 additions & 0 deletions tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tempfile

import pytest
import torch
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroCheckpointIO
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)


@clear_cache_before_run()
@parameterize('stage', [2])
def check_low_level_zero_checkpointIO(stage: int):
    """Round-trip an optimizer checkpoint through LowLevelZeroCheckpointIO.

    A resnet18 with HybridAdam is boosted with ``LowLevelZeroPlugin``, trained
    for one step on random input, and its optimizer is saved to a temporary
    file. On the master process a freshly boosted optimizer loads that file
    and its state dict is compared with the original one.

    Args:
        stage: Value forwarded to ``LowLevelZeroPlugin(stage=...)``.
    """

    def mean_loss(x):
        return x.mean()

    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
    booster = Booster(plugin=plugin)
    model = resnet18()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, mean_loss)

    # One forward/backward/step before saving (presumably so the optimizer
    # holds populated state worth checkpointing).
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.step()

    # The context manager closes and removes the temporary file even when a
    # save, load or comparison below raises.
    with tempfile.NamedTemporaryFile() as optimizer_ckpt_tempfile:
        ckpt_io = LowLevelZeroCheckpointIO()
        ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)

        if ckpt_io.coordinator.is_master():
            new_model = resnet18()
            new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
            _, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
            ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
            # Passing False as ignore_device makes the helper move tensors to CPU before comparing.
            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


def run_dist(rank, world_size, port):
    """Initialise colossalai for this process, then run the low-level-zero checkpoint check."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host='localhost',
    )
    check_low_level_zero_checkpointIO()


@rerun_if_address_is_in_use()
def test_low_level_zero_checkpointIO():
    """Pytest entry point: hand run_dist to spawn with a count of 2."""
    nprocs = 2
    spawn(run_dist, nprocs)
Loading