Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions colossalai/booster/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,17 @@ class Booster:
```

Args:
device (str or torch.device): The device to run the training. Default: 'cuda'.
device (str or torch.device): The device to run the training. Default: None.
If plugin is not used or plugin doesn't control the device,
this argument will be set as training device ('cuda' will be used if argument is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
plugin (Plugin): The plugin to run the training. Default: None.
"""

def __init__(self,
device: str = 'cuda',
device: Optional[str] = None,
mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None:
if plugin is not None:
Expand All @@ -68,13 +70,16 @@ def __init__(self,
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
if device is not None:
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
device = device or 'cuda'
self.accelerator = Accelerator(device)

# set precision
if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
if mixed_precision is not None:
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
Expand Down Expand Up @@ -146,7 +151,7 @@ def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optimizer,
optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
# run pipeline forward backward pass
Expand Down
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,15 @@ def execute_pipeline(self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
HybridParallelZeroOptimizer],
optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
HybridParallelZeroOptimizer]] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
return_outputs)
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
Expand Down
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/pp_plugin_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterator, Optional

import torch

Expand All @@ -15,7 +15,7 @@ def execute_pipeline(self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: OptimizerWrapper,
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
pass
6 changes: 3 additions & 3 deletions colossalai/pipeline/schedule/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Optional

from torch import Tensor
from torch.nn import Module
Expand All @@ -14,18 +14,18 @@ def __init__(self, stage_manager: PipelineStageManager) -> None:

def forward_backward_step(self,
model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[[Any, Any], Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Forward and backward step for pipeline training.

Args:
model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

Expand Down
6 changes: 4 additions & 2 deletions colossalai/pipeline/schedule/interleaved_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,27 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],

def forward_backward_step(self,
model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.

Args:
model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."

self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
Expand Down
6 changes: 4 additions & 2 deletions colossalai/pipeline/schedule/one_f_one_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,25 +210,27 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],

def forward_backward_step(self,
model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

Args:
model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."

self.load_batch(data_iter)

Expand Down
10 changes: 2 additions & 8 deletions examples/language/bert/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def move_to_cuda(batch):
@torch.no_grad()
def evaluate_model(
model: nn.Module,
optimizer,
criterion,
test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int,
Expand All @@ -71,12 +70,7 @@ def evaluate_subset(dataloader: DataLoader):
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]
Expand Down Expand Up @@ -304,7 +298,7 @@ def _criterion(outputs, inputs):
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator)

if coordinator.is_master():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipeline/test_schedule/test_interleaved.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def examine_pp(num_micro_batches):
torch_loss.backward()

pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list),
criterion,
pp_optimizer,
return_loss=True,
return_outputs=True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipeline/test_schedule/test_oneF_oneB.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def examine_pp():
torch_loss.backward()

pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list),
criterion,
pp_optimizer,
return_loss=True,
return_outputs=True)

Expand Down