8 changes: 4 additions & 4 deletions colossalai/booster/booster.py
@@ -97,10 +97,10 @@ def __init__(self,
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
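Why the defaults change: once optimizer and the other arguments may be None, Booster.boost can prepare a model on its own, for example when only evaluation is needed. A minimal usage sketch (not part of this PR; it assumes the distributed environment has already been initialized via colossalai.launch_from_torch and uses TorchDDPPlugin as a stand-in):

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})    # assumes a torchrun-style launch
booster = Booster(plugin=TorchDDPPlugin())

model = nn.Linear(16, 8)                   # placeholder model

# With the new Optional defaults, boost() no longer requires an optimizer;
# the slots that were not supplied come back as None.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)
assert optimizer is None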
8 changes: 5 additions & 3 deletions colossalai/booster/mixed_precision/fp16_torch.py
@@ -115,10 +115,12 @@ def __init__(self,
 
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
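The same relaxation at the mixed-precision level: configure can now wrap only the model. A small sketch, assuming the class in this file is exported as FP16TorchMixedPrecision (that name and the toy model are assumptions, not part of the diff):

import torch.nn as nn

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

mp = FP16TorchMixedPrecision()
model = nn.Linear(4, 4)    # toy model

# Only the model is wrapped; optimizer and criterion default to None and
# are returned unchanged thanks to the new guards above.
model, optimizer, criterion = mp.configure(model)
assert optimizer is None and criterion is None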
7 changes: 4 additions & 3 deletions colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass
18 changes: 11 additions & 7 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -274,11 +274,11 @@ def supported_devices(self) -> List[str]:
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ def configure(
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
18 changes: 11 additions & 7 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -197,17 +197,21 @@ def supported_devices(self) -> List[str]:
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
12 changes: 6 additions & 6 deletions colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ def support_no_sync(self) -> bool:
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
 
13 changes: 7 additions & 6 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -138,11 +138,11 @@ def supported_devices(self) -> List[str]:
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
 
@@ -152,7 +152,8 @@ def configure(
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)
 
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
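Each plugin in this PR applies the same guard before wrapping the optimizer. A hypothetical helper that captures the pattern in one place (the function is illustrative only, not part of the codebase):

from typing import Optional

from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper


def maybe_wrap_optimizer(optimizer: Optional[Optimizer]) -> Optional[OptimizerWrapper]:
    # Wrap a raw optimizer, pass through an already-wrapped one, leave None alone.
    if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
        optimizer = OptimizerWrapper(optimizer)
    return optimizer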
25 changes: 13 additions & 12 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -195,23 +195,24 @@ def supported_devices(self) -> List[str]:
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 
@@ -4,12 +4,15 @@
 import torch
 
 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False
 
 from test_autochunk_diffuser_utils import run_test
 
@@ -32,6 +32,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
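The new skip gate simply compares the installed diffusers release against 0.10.2 with packaging.version. A standalone sketch of the same check (the version string is illustrative):

from packaging import version

installed = "0.14.0"    # illustrative value; the test reads diffusers.__version__
skip_unet_test = version.parse(installed) > version.parse("0.10.2")
print(skip_unet_test)   # True -> the UNet test would be skipped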