Skip to content
Merged

fix #6327

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/test_zero/test_gemini/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
Expand Down Expand Up @@ -53,6 +53,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
return model


@rerun_if_address_is_in_use()
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
Expand Down Expand Up @@ -104,6 +105,7 @@ def inference_iter():
train_iter()
inference_iter()
train_iter()
torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
Expand All @@ -112,8 +114,8 @@ def run_dist(rank, world_size, port):


@pytest.mark.dist
@clear_cache_before_run()
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
    """Pytest entry point: launch the distributed Gemini inference test.

    Spawns ``run_dist`` across ``world_size`` processes (parametrized over
    1 and 4 ranks). The ``colossalai.testing`` decorators (imported at the
    top of this file) wrap the run: ``rerun_if_address_is_in_use`` retries
    the test when the rendezvous port is already taken, and
    ``clear_cache_before_run`` presumably frees cached GPU memory before
    the test starts — NOTE(review): semantics inferred from the names,
    confirm against colossalai.testing.
    """
    spawn(run_dist, world_size)

Expand Down