19 changes: 9 additions & 10 deletions .github/workflows/build_on_pr.yml
@@ -8,10 +8,10 @@ jobs:
detect:
name: Detect file change
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
outputs:
changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
@@ -27,10 +27,10 @@ jobs:
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT

- name: Find the changed extension-related files
id: find-extension-change
@@ -63,7 +63,6 @@ jobs:
echo "$file was changed"
done


build:
name: Build and Test Colossal-AI
needs: detect
@@ -124,7 +123,7 @@ jobs:
- name: Execute Unit Testing
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
run: |
PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
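For context on the `locate-base-sha` step above, the snippet below is an illustrative sketch (not part of this PR) of the same base-commit detection done locally; it assumes it runs inside a clone whose `origin/main` branch has been fetched.

```python
import subprocess

# Name of the currently checked-out branch, e.g. "feature/refactor-testing".
cur_branch = subprocess.check_output(
    ["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True).strip()

# Common ancestor of the current branch and origin/main; the workflow exports
# this SHA as `baseSHA` for the changed-files detection step that follows.
common_commit = subprocess.check_output(
    ["git", "merge-base", "origin/main", cur_branch], text=True).strip()

print(common_commit)
```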
8 changes: 2 additions & 6 deletions applications/Chat/tests/test_checkpoint.py
@@ -1,19 +1,16 @@
import os
import tempfile
from contextlib import nullcontext
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.models.gpt import GPTActor
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)

@@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)


if __name__ == '__main__':
8 changes: 2 additions & 6 deletions applications/Chat/tests/test_data.py
@@ -1,20 +1,17 @@
import os
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)

@@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)


if __name__ == '__main__':
3 changes: 2 additions & 1 deletion colossalai/cli/benchmark/benchmark.py
@@ -10,7 +10,8 @@
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import MultiTimer, free_port
from colossalai.testing import free_port
from colossalai.utils import MultiTimer

from .models import MLP

16 changes: 13 additions & 3 deletions colossalai/testing/__init__.py
@@ -1,7 +1,17 @@
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
from .pytest_wrapper import run_on_environment_flag
from .utils import (
clear_cache_before_run,
free_port,
parameterize,
rerun_if_address_is_in_use,
rerun_on_exception,
skip_if_not_enough_gpus,
spawn,
)

__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
'clear_cache_before_run', 'run_on_environment_flag'
]
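With these re-exports, the port and process helpers that tests previously pulled from `colossalai.utils` now come from `colossalai.testing`. A minimal import sketch (illustrative only, not code from this PR):

```python
# Old import path (removed by this PR):
#   from colossalai.utils import free_port
# New import path:
from colossalai.testing import clear_cache_before_run, free_port, spawn

port = free_port()    # still available directly for tests that manage processes themselves
```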
88 changes: 81 additions & 7 deletions colossalai/testing/utils.py
@@ -1,8 +1,13 @@
import gc
import random
import re
import torch
from typing import Callable, List, Any
import socket
from functools import partial
from inspect import signature
from typing import Any, Callable, List

import torch
import torch.multiprocessing as mp
from packaging import version


@@ -43,7 +48,7 @@ def say_something(person, msg):
# > davis: hello
# > davis: bye
# > davis: stop

Args:
argument (str): the name of the argument to parameterize
values (List[Any]): a list of values to iterate for this argument
@@ -85,13 +90,13 @@ def test_method():
def test_method():
print('hey')
raise RuntimeError('Address already in use')

# rerun for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
def test_method():
print('hey')
raise RuntimeError('Address already in use')

# rerun only the exception message is matched with pattern
# for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
@@ -101,10 +106,10 @@ def test_method():

Args:
exception_type (Exception, Optional): The type of exception to detect for rerun
pattern (str, Optional): The pattern to match the exception message.
pattern (str, Optional): The pattern to match the exception message.
If the pattern is not None and matches the exception message,
the exception will be detected for rerun
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
If max_try is None, it will rerun forever if the exception keeps occurring
"""

@@ -202,3 +207,72 @@ def _execute_by_gpu_num(*args, **kwargs):
return _execute_by_gpu_num

return _wrap_func


def free_port() -> int:
"""Get a free port on localhost.

Returns:
int: A free port on localhost.
"""
while True:
port = random.randint(20000, 65000)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue


def spawn(func, nprocs=1, **kwargs):
"""
This function is used to spawn processes for testing.

Usage:
# must contain the arguments rank, world_size, port
def do_something(rank, world_size, port):
...

spawn(do_something, nprocs=8)

# can also pass other arguments
def do_something(rank, world_size, port, arg1, arg2):
...

spawn(do_something, nprocs=8, arg1=1, arg2=2)

Args:
func (Callable): The function to be spawned.
nprocs (int, optional): The number of processes to spawn. Defaults to 1.
"""
port = free_port()
wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
mp.spawn(wrapped_func, nprocs=nprocs)


def clear_cache_before_run():
"""
This function is a wrapper to clear CUDA and python cache before executing the function.

Usage:
@clear_cache_before_run()
def test_something():
...
"""

def _wrap_func(f):

def _clear_cache(*args, **kwargs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()
torch.cuda.synchronize()
gc.collect()
f(*args, **kwargs)

return _clear_cache

return _wrap_func
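Taken together, the new helpers collapse the old `partial` + `free_port` + `mp.spawn` boilerplate into a single call. The following is a minimal sketch of a migrated distributed test; `run_dist` and `test_something` are hypothetical names, and the `colossalai.launch` call stands in for whatever initialization the real test performs.

```python
import pytest

import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port, strategy):
    # spawn() supplies world_size and a free port as keyword arguments;
    # torch.multiprocessing passes the rank as the first positional argument.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    ...


@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_something(world_size, strategy):
    # Old pattern:
    #   run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
    #   mp.spawn(run_func, nprocs=world_size)
    # New pattern:
    spawn(run_dist, world_size, strategy=strategy)
```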
2 changes: 0 additions & 2 deletions colossalai/utils/__init__.py
@@ -7,7 +7,6 @@
count_zeros_fp32,
disposable,
ensure_path_exists,
free_port,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
@@ -37,7 +36,6 @@

__all__ = [
'checkpoint',
'free_port',
'print_rank_0',
'sync_model_param',
'is_ddp_ignored',
17 changes: 0 additions & 17 deletions colossalai/utils/common.py
@@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
Path(dirpath).mkdir(parents=True, exist_ok=True)


def free_port() -> int:
"""Get a free port on localhost.

Returns:
int: A free port on localhost.
"""
while True:
port = random.randint(20000, 65000)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue


def sync_model_param(model, parallel_mode):
r"""Make sure data parameters are consistent during Data Parallel Mode.

1 change: 1 addition & 0 deletions docs/requirements-doc-test.txt
@@ -4,3 +4,4 @@ packaging
tensornvme
psutil
transformers
pytest
7 changes: 3 additions & 4 deletions docs/source/en/basics/colotensor_concept.md
@@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial

import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn

import torch

@@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")

def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)

if __name__ == '__main__':
test_dist_cases(4)
7 changes: 3 additions & 4 deletions docs/source/zh-Hans/basics/colotensor_concept.md
@@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial

import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn

import torch

@@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")

def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)

if __name__ == '__main__':
test_dist_cases(4)
8 changes: 2 additions & 6 deletions examples/images/vit/test_vit.py
@@ -1,11 +1,9 @@
import os
import random
from functools import partial

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import get_training_components

@@ -15,8 +13,7 @@
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext

@@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp)


if __name__ == '__main__':