Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ In this way, users can train their models as usual.
In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency. For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process:

```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
return model
Expand Down
8 changes: 4 additions & 4 deletions docs/source/en/features/zero_with_chunk.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,23 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
Define a model which uses Gemini + ZeRO DDP:

```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placememt_policy))
init_device=GeminiManager.get_default_device(placement_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ for mn, module in model.named_modules():
在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销,提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md),你可以将这两部分内容结合起来看从而理解我们整个训练流程:

```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
return model
Expand Down
8 changes: 4 additions & 4 deletions docs/source/zh-Hans/features/zero_with_chunk.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,23 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
定义一个使用 Gemini + ZeRO DDP 的模型:

```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placememt_policy))
init_device=GeminiManager.get_default_device(placement_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
Expand Down
4 changes: 2 additions & 2 deletions examples/images/dreambooth/train_dreambooth_colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
from colossalai.nn.parallel import GeminiDDP

model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=64)
return model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
from colossalai.nn.parallel import GeminiDDP

model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=64)
return model
Expand Down
8 changes: 4 additions & 4 deletions examples/language/palm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,23 @@ def get_model_size(model: nn.Module):


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placememt_policy))
init_device=GeminiManager.get_default_device(placement_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
Expand Down