3 changes: 2 additions & 1 deletion colossalai/zero/legacy/__init__.py
@@ -6,6 +6,7 @@
from colossalai.logging import get_dist_logger

from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2

@@ -40,5 +41,5 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model

__all__ = [
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
'no_shard_zero_decrator'
'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy'
]
7 changes: 6 additions & 1 deletion examples/language/roberta/configs/colossalai_ddp.py
@@ -1,4 +1,9 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam

try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy

clip_grad_norm = 1.0
9 changes: 7 additions & 2 deletions examples/language/roberta/configs/colossalai_zero.py
@@ -1,6 +1,11 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam

try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy

# fp16 = dict(
# mode=AMP_TYPE.TORCH,
# )
@@ -29,4 +34,4 @@
weight_decay=1e-2,
)

# 64433
# 64433
6 changes: 5 additions & 1 deletion examples/tutorial/opt/opt/colossalai_zero.py
@@ -1,4 +1,8 @@
from colossalai.zero.shard_utils import TensorShardStrategy
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy

zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",
1 change: 1 addition & 0 deletions examples/tutorial/opt/opt/requirements.txt
@@ -4,3 +4,4 @@ datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
transformers
6 changes: 5 additions & 1 deletion examples/tutorial/opt/opt/run_clm.py
@@ -413,7 +413,11 @@ def main():
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
try:
from colossalai.nn.parallel import GeminiDDP
except ImportError:
# this works for unreleased main branch, and this may be released on 0.2.9
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
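For context, a standalone sketch of the version-gated import that the run_clm.py hunk above relies on; this is not part of the diff, and it assumes only that colossalai and packaging are installed (the surrounding logger and model setup in run_clm.py are omitted):

import colossalai
from packaging import version

cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
    try:
        # released versions expose GeminiDDP here
        from colossalai.nn.parallel import GeminiDDP
    except ImportError:
        # unreleased main branch; expected to land in a later release (0.2.9)
        from colossalai.zero import GeminiDDP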
21 changes: 21 additions & 0 deletions examples/tutorial/opt/opt/test_ci.sh
@@ -0,0 +1,21 @@
#!/bin/bash

set -xue

pip install -r requirements.txt

BS=8
MEMCAP=0
GPUNUM=2
MODEL="facebook/opt-125m"

torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
-s \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODEL} \
--per_device_train_batch_size ${BS} \
--num_train_epochs 1
3 changes: 3 additions & 0 deletions examples/tutorial/opt/test_ci.sh
@@ -0,0 +1,3 @@
#!/bin/bash

cd opt && bash test_ci.sh