172 changes: 61 additions & 111 deletions colossalai/engine/_base_engine.py
@@ -12,55 +12,40 @@
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule


class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step`, which is based on the given :attr:`schedule`, over each batch of a dataset.
It controls an iteration in training.

:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
:param schedule: Running schedule in :meth:`step`
:type train_dataloader: DataLoader, optional
:type test_dataloader: DataLoader, optional
:param step_schedule: Running schedule in :meth:`step`
:param gradient_accumulation: Number of gradient accumulation steps
:param gradient_clipping: Max norm used for gradient clipping
:type model: Module
:type criterion: _Loss, optional
:type optimizer: Optimizer, optional
:type lr_scheduler: _LRScheduler, optional
:type schedule: BaseSchedule, optional
:type optimizer: Optimizer
:type step_schedule: BaseSchedule, optional
:type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
"""

def __init__(self,
train_dataloader: Optional[DataLoader] = None,
test_dataloader: Optional[DataLoader] = None,
model: Module = None,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None,
model: Module,
optimizer: Optimizer,
step_schedule: BaseSchedule = None,
gradient_accumulation: int = 1,
lr_scheduler_step: str = 'epoch'):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
gradient_clipping: float = 0.0):
self.schedule = step_schedule if step_schedule is not None \
else NoPipelineSchedule()
self.schedule.initialize(model, optimizer)
self.grad_accum_size = gradient_accumulation
self.grad_accum_step = 0
self.lr_step = 0 # for epoch updating
if lr_scheduler_step != 'epoch':
self.lr_step = 1
self.grad_accum_cur_step = 0
self.grad_clip = gradient_clipping
self.training = True # default
self._logger = get_global_dist_logger()

# build gradient handler
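
Note on the new constructor (an illustrative sketch, not part of this PR): after this change the Engine takes the model and optimizer as required arguments and no longer owns the dataloaders, criterion, or lr_scheduler; the schedule, gradient accumulation steps, and clipping norm are optional. The snippet below only shows how the new signature would be called; the import path, model, and optimizer are placeholders, and a real run would also need an initialized colossalai context since __init__ reads gpc.config and the global distributed logger.

import torch
from torch import nn
from colossalai.engine import Engine  # assumed export path for _base_engine.Engine

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# New signature: model and optimizer are required; step_schedule defaults to
# NoPipelineSchedule(), and accumulation/clipping are plain keyword arguments.
engine = Engine(model=model,
                optimizer=optimizer,
                step_schedule=None,
                gradient_accumulation=4,   # optimizer update every 4 step() calls
                gradient_clipping=1.0)     # norm forwarded to schedule.optimizer_step
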
@@ -72,8 +57,8 @@ def __init__(self,
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}'
gradient_handler_cfg = gpc.config.gradient_handler
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
@@ -92,106 +77,71 @@ def __init__(self,
"to all-reduce the gradients after a training step.",
ranks=[0])
for cfg in gradient_handler_cfg:
handler = build_gradient_handler(cfg, self.model, self.optimizer)
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler)

self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer)
self.schedule.grad_accum = self.grad_accum_size
# add for robustness
if self.optimizer is None:
self.forward_only = True
else:
self.forward_only = False

def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
handler.handle_gradient()

def set_dataloader(self, data: DataLoader, train: bool = True):
"""Sets dataloader in training or evaluation.

:param data: Dataloader to be set
:param train: Set training dataloader if True, otherwise evaluation dataloader
:type data: DataLoader
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data

def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model

def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer

def get_lr_scheduler(self):
"""Returns the learning rate scheduler in the engine.
"""
return self.lr_scheduler

def train(self):
"""Sets the model to training mode.
"""
self.forward_only = False
self.schedule.train(dataloader=self.train_dataloader, mode=True)
self.training = True

def eval(self):
"""Sets the model to evaluation mode.
"""
self.forward_only = True
self.schedule.train(dataloader=self.test_dataloader, mode=False)

def is_train(self):
"""Returns True if it is in training, otherwise False.
"""
return not self.forward_only

def get_lr(self):
"""Gets current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']

def step(self, return_loss=True):
self.training = False

def step(self,
data_iter,
model: Module,
criterion: _Loss,
optimizer: Optimizer = None,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of the dataset.

:param data_iter: Data iterator of the dataset
:param model: The neural network model
:param criterion: Loss function used to calculate the loss
:param optimizer: Optimizer for updating the parameters
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type return_loss: bool
:type data_iter: Iterator
:type model: Module
:type criterion: _Loss
:type optimizer: Optimizer, optional
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
if not self.forward_only and self.grad_accum_step == 0:
self.schedule.zero_grad()
if self.training and self.grad_accum_cur_step == 0:
optimizer.zero_grad()

output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)

if not self.forward_only:
self.grad_accum_step += 1
if self.grad_accum_step == self.grad_accum_size:
data_iter, model, criterion, optimizer,
forward_only=not self.training,
grad_accum_size=self.grad_accum_size,
return_loss=return_loss)

if self.training:
self.grad_accum_cur_step += 1
if self.grad_accum_cur_step == self.grad_accum_size:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
if self.lr_scheduler is not None and self.lr_step:
self.lr_scheduler.step()
self.grad_accum_step = 0
self.schedule.optimizer_step(model, optimizer, self.grad_clip)
self.grad_accum_cur_step = 0

return output, label, loss
if is_last_iteration:
while True:
try:
trash = next(data_iter)
except StopIteration:
break

def complete(self):
"""Updating after a epoch.
"""
self.schedule.consume_batch()
if self.lr_scheduler is not None and self.lr_step == 0:
self.lr_scheduler.step()
return output, label, loss
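
Usage note (a hypothetical sketch, not from this PR): step() now receives the data iterator, model, criterion, and optimizer on every call instead of reading them from the engine, and the caller switches modes with train()/eval(). The loop below illustrates how an epoch might drive the new interface; the dataloader and criterion are placeholders, and the optimizer update happens inside step() once enough accumulation steps have run.

def run_epoch(engine, dataloader, model, criterion, optimizer):
    engine.train()                      # schedule switches to training mode
    data_iter = iter(dataloader)
    num_steps = len(dataloader)
    last_loss = None
    for i in range(num_steps):
        # zero_grad and the optimizer update are handled inside step()
        # according to the configured gradient accumulation size.
        output, label, loss = engine.step(
            data_iter,
            model,
            criterion,
            optimizer,
            is_last_iteration=(i == num_steps - 1),  # drains any leftover batches
            return_loss=True)
        last_loss = loss
    return last_loss
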