48 changes: 44 additions & 4 deletions colossalai/booster/accelerator.py
@@ -3,12 +3,52 @@

 __all__ = ['Accelerator']
 
+_supported_devices = [
+    'cpu',
+    'cuda',
+
+    # To be supported
+    # 'xpu',
+    # 'npu',
+    # 'tpu',
+]
+
 
 class Accelerator:
+    """
+    Accelerator is an abstraction for the hardware device that is used to run the model.
+
+    Args:
+        device (str): The device to be used. Currently only 'cpu' and 'cuda' are supported.
+    """
 
-    def __init__(self, device: torch.device):
+    def __init__(self, device: str):
         self.device = device
 
-    def setup_model(self, model: nn.Module) -> nn.Module:
-        # TODO: implement this method
-        pass
+        assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+
+    def bind(self):
+        """
+        Set the default device for the current process.
+        """
+        if self.device == 'cpu':
+            pass
+        elif self.device == 'cuda':
+            # TODO(FrankLeeeee): use global environment to check if it is a dist job
+            # if is_distributed:
+            #     local_rank = EnvTable().get_local_rank()
+            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
+            torch.cuda.set_device(torch.device('cuda'))
+        else:
+            raise ValueError(f"Device {self.device} is not supported yet")
+
+    def configure_model(self, model: nn.Module) -> nn.Module:
+        """
+        Move the model to the device.
+
+        Args:
+            model (nn.Module): The model to be moved.
+        """
+        model = model.to(torch.device(self.device))
+        return model
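
For context, a minimal usage sketch of the new Accelerator API, using only the calls introduced in this diff (the import path is taken from the test added below):

    import torch.nn as nn
    from colossalai.booster.accelerator import Accelerator

    accelerator = Accelerator('cuda')   # 'cpu' is the other supported device
    accelerator.bind()                  # set the default device for this process
    model = accelerator.configure_model(nn.Linear(8, 8))  # move the model onto the device
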
15 changes: 14 additions & 1 deletion colossalai/booster/booster.py
@@ -8,6 +8,7 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
 
@@ -51,9 +52,16 @@ class Booster:
"""

def __init__(self,
device: Union[str, torch.device] = 'cuda',
device: str = 'cuda',
mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None:
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
# create acclerator
self.acceleartor = Accelerator(device)
self.acceleartor.set_default_device()

# validate and set precision
if isinstance(MixedPrecision, str):
# the user will take the default arguments for amp training
@@ -78,6 +86,11 @@ def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_
             lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
             dataloader (DataLoader): The dataloader to be boosted.
         """
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        model = self.accelerator.configure_model(model)
+
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(lsg): Add plugin control logic
         # e.g.
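
Putting the two files together, the intended end-to-end flow looks roughly like the sketch below. This is pieced together from this diff alone: boost()'s return value is not shown here, so the tuple unpacking is illustrative, not confirmed.

    import torch.nn as nn
    from torch.optim import SGD
    from colossalai.booster import Booster  # import path assumed from the package layout

    model = nn.Linear(8, 8)
    optimizer = SGD(model.parameters(), lr=1e-3)

    # the Accelerator is created and bound inside Booster.__init__
    booster = Booster(device='cuda', mixed_precision='fp16')
    # boost() now moves the model to the device via Accelerator.configure_model()
    model, optimizer, criterion, lr_scheduler, dataloader = booster.boost(
        model, optimizer, criterion, lr_scheduler, dataloader)
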
13 changes: 13 additions & 0 deletions tests/test_booster/test_accelerator.py
@@ -0,0 +1,13 @@
+import pytest
+import torch.nn as nn
+from torchvision.models import resnet18
+
+from colossalai.booster.accelerator import Accelerator
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+def test_accelerator(device):
+    accelerator = Accelerator(device)
+    model = nn.Linear(8, 8)
+    model = accelerator.configure_model(model)
+    assert next(model.parameters()).device.type == device
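
One caveat with the parametrized test: the 'cuda' case will fail on a machine without a GPU. A guarded variant (a sketch, not part of this PR) could skip that case explicitly:

    import pytest
    import torch
    import torch.nn as nn

    from colossalai.booster.accelerator import Accelerator


    @pytest.mark.parametrize('device', ['cpu', 'cuda'])
    def test_accelerator(device):
        # skip the CUDA case on CPU-only machines
        if device == 'cuda' and not torch.cuda.is_available():
            pytest.skip('CUDA is not available')
        accelerator = Accelerator(device)
        model = accelerator.configure_model(nn.Linear(8, 8))
        assert next(model.parameters()).device.type == device
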
@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
         data = data_gen_fn()
 
         # dlrm_interactionarch is not supported
+        # TODO(FrankLeeeee): support this model
         if name == 'dlrm_interactionarch':
             continue
 