11 changes: 5 additions & 6 deletions colossalai/builder/builder.py
@@ -235,7 +235,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)


def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.

@@ -255,8 +255,7 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
config_ = config.copy()
mod_type = config_.pop('type')
# warmup epochs will overwrite warmup steps
if 'warmup_epochs' in config_:
warmup_epochs = config_.pop('warmup_epochs')
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
**config_)
# if 'warmup_epochs' in config_:
# warmup_epochs = config_.pop('warmup_epochs')
# config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
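The builder no longer receives total_steps or num_steps_per_epoch from the caller, so any step counts the scheduler needs must be carried in the lr_scheduler config itself. A self-contained sketch of the new call pattern (the LinearWarmupLR stand-in and its warmup/decay behaviour are assumptions for illustration, not the library's implementation):

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

# Stand-in for a warmup scheduler: linear warmup for warmup_steps, then linear
# decay to zero over the remaining total_steps (assumed behaviour).
def linear_warmup_lr(optimizer, total_steps, warmup_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return LambdaLR(optimizer, lr_lambda)

LR_SCHEDULERS = {'LinearWarmupLR': linear_warmup_lr}

def build_lr_scheduler(config, optimizer):
    # mirrors the simplified builder: step counts now live in the config dict
    config_ = dict(config)
    mod_type = config_.pop('type')
    return LR_SCHEDULERS[mod_type](optimizer, **config_)

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = build_lr_scheduler(dict(type='LinearWarmupLR', total_steps=60, warmup_steps=5), opt)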
47 changes: 37 additions & 10 deletions colossalai/engine/_base_engine.py
@@ -44,7 +44,9 @@ def __init__(self,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
schedule: BaseSchedule = None,
gradient_accumulation: int = 1,
lr_scheduler_step: str = 'epoch'):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
@@ -54,6 +56,11 @@ def __init__(self,
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
self.grad_accum_size = gradient_accumulation
self.grad_accum_step = 0
self.lr_step = 0 # for epoch updating
if lr_scheduler_step != 'epoch':
self.lr_step = 1
self._logger = get_global_dist_logger()

# build gradient handler
@@ -89,9 +96,13 @@ def __init__(self,
self._gradient_handlers.append(handler)

self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
self.criterion, self.optimizer)
self.schedule.grad_accum = self.grad_accum_size
# add for robustness
if self.optimizer is None:
self.forward_only = True
else:
self.forward_only = False

def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
@@ -116,6 +127,7 @@ def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model

def get_optimizer(self):
"""Returns optimizier in the engine.
"""
@@ -146,7 +158,10 @@ def is_train(self):
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']

def step(self, return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
@@ -156,15 +171,27 @@ def step(self, return_loss=True):
:type return_loss: bool
:return: (output, lablel, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)
if not self.forward_only and self.grad_accum_step == 0:
self.schedule.zero_grad()

output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)

if not self.forward_only:
# all reduce gradients
self.handle_gradient()

self.schedule.step()
self.grad_accum_step += 1
if self.grad_accum_step == self.grad_accum_size:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
if self.lr_scheduler is not None and self.lr_step:
self.lr_scheduler.step()
self.grad_accum_step = 0

return output, label, loss

def complete(self):
"""Updating after a epoch.
"""
self.schedule.consume_batch()
if self.lr_scheduler is not None and self.lr_step == 0:
self.lr_scheduler.step()
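In plain PyTorch terms, the accumulation flow the engine now implements looks roughly like the sketch below (illustrative only; in the engine the gradient all-reduce, schedule.step() and a per-step LR scheduler all fire at the same point as optimizer.step() here):

import torch
from torch import nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
grad_accum_size, grad_accum_step = 4, 0

for _ in range(16):
    if grad_accum_step == 0:
        optimizer.zero_grad()                        # only at the start of a window
    x, y = torch.randn(2, 8), torch.randn(2, 2)
    loss = criterion(model(x), y) / grad_accum_size  # the schedule applies this scaling
    loss.backward()
    grad_accum_step += 1
    if grad_accum_step == grad_accum_size:
        optimizer.step()                             # one update per accumulation window
        grad_accum_step = 0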
26 changes: 11 additions & 15 deletions colossalai/engine/schedule/_base_schedule.py
@@ -15,6 +15,8 @@ class BaseSchedule(ABC):
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
self.grad_accum = 1
self.training = False

@property
@abstractmethod
@@ -27,15 +29,13 @@ def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
optimizer=None):
"""Initializes the schedule and set parameters before running.

:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
@@ -44,7 +44,6 @@ def initialize(self,
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True

def check_initialized(self):
@@ -66,6 +65,13 @@ def load_batch(self):
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)

def consume_batch(self):
while True:
try:
self.load_batch()
except StopIteration:
break

def _move_to_device(self, data):
if isinstance(data, (
tuple,
@@ -87,6 +93,7 @@ def train(self, dataloader=None, mode=True):
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
"""
self.check_initialized()
self.training = mode
if mode:
self.model.train()
else:
@@ -102,22 +109,11 @@ def zero_grad(self, forward_only=False):
self.check_initialized()
self.optimizer.zero_grad()

def get_lr(self):
"""Returns the current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']

def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()

@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
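During training, num_steps is now truncated to a whole number of accumulation windows, so an epoch can leave a few batches unconsumed; consume_batch drains them so the next epoch starts from a fresh iterator. A minimal sketch of both pieces (function names are illustrative):

def effective_num_steps(batches_per_epoch, grad_accum, training=True):
    # mirrors the new num_steps property: drop the tail that cannot fill a window
    if training:
        batches_per_epoch -= batches_per_epoch % grad_accum
    return batches_per_epoch

def consume_batch(data_iter):
    # drain whatever is left; the engine triggers this from complete()
    while True:
        try:
            next(data_iter)
        except StopIteration:
            break

assert effective_num_steps(103, 5) == 100
consume_batch(iter(range(3)))  # the three leftover "batches" are simply discarded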
24 changes: 11 additions & 13 deletions colossalai/engine/schedule/_no_pipeline.py
@@ -81,19 +81,20 @@ def __init__(

@property
def num_steps(self):
return len(self.dataloader)
length = len(self.dataloader)
if self.training:
length -= length % self.grad_accum
return length

def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
dataloader=None,
model=None,
criterion=None,
optimizer=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
optimizer)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
@@ -147,6 +148,7 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss /= self.grad_accum

if not forward_only:
# backward
@@ -168,7 +170,7 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
loss.backward()

if return_loss:
return output, label, loss
return output, label, loss * self.grad_accum
else:
return output, None, None

@@ -179,7 +181,3 @@ def step(self):
self._torch_amp_scaler.update()
else:
self.optimizer.step()

# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
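The schedule now divides the loss by grad_accum before backward, so gradients summed over a window equal those of one averaged batch, and rescales the returned value so logged losses stay comparable. A small numeric check:

import torch

grad_accum = 4
w = torch.ones(1, requires_grad=True)
for x in [1.0, 2.0, 3.0, 4.0]:
    loss = (w * x).sum() / grad_accum    # scaled copy used for backward
    loss.backward()
    reported = loss.item() * grad_accum  # value handed back to the caller
# the accumulated gradient is the mean over the window: (1 + 2 + 3 + 4) / 4
assert torch.isclose(w.grad, torch.tensor([2.5]))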
21 changes: 11 additions & 10 deletions colossalai/engine/schedule/_pipeline.py
@@ -119,19 +119,20 @@ def load_micro_batch(self):

@property
def num_steps(self):
return len(self.dataloader)
length = len(self.dataloader)
if self.training:
length -= length % self.grad_accum
return length

def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
dataloader=None,
model=None,
criterion=None,
optimizer=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
optimizer)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
@@ -163,7 +164,7 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
label) / (self.num_microbatches * self.grad_accum)
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@@ -309,7 +310,7 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
sum(loss) * self.grad_accum)
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
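With pipelining, each micro-batch loss is divided by num_microbatches * grad_accum before backward, and the per-step sum is multiplied back by grad_accum when reported, so the logged value is still the mean over the micro-batches. A quick check of the arithmetic (values are illustrative):

num_microbatches, grad_accum = 4, 2
micro_losses = [1.0, 2.0, 3.0, 4.0]
scaled = [l / (num_microbatches * grad_accum) for l in micro_losses]
reported = sum(scaled) * grad_accum            # what forward_backward_step returns
assert abs(reported - sum(micro_losses) / num_microbatches) < 1e-9  # == 2.5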
19 changes: 9 additions & 10 deletions colossalai/initialize.py
@@ -339,16 +339,15 @@ def initialize(config: Union[str, dict] = None,

lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
# if hasattr(gpc.config, 'num_steps'):
# total_steps = gpc.config.num_steps
# elif hasattr(gpc.config, 'num_epochs'):
# total_steps = int(gpc.config.num_epochs * len(train_dataloader))
# else:
# raise Exception(
# 'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
# )
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
logger.info('Learning rate scheduler is created', ranks=[0])

# pipeline or no pipeline schedule
1 change: 1 addition & 0 deletions colossalai/trainer/_trainer.py
@@ -147,6 +147,7 @@ def _train_epoch(self, epoch: int = None):
if self.exceed_max_step():
# stop when max iter is reached
break
self._engine.complete()
self._timer.stop('train-epoch', keep_in_history=True)
self.call_hooks('after_train_epoch')
self._timer.reset('train-step')
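From the trainer's side, an epoch now ends with engine.complete(), which drains any leftover batches and, when lr_scheduler_step is 'epoch', steps the scheduler once. A sketch of the resulting flow, assuming an engine built as in examples/run_trainer.py (the loop body is illustrative, not the trainer's exact code):

def train_epochs(engine, num_epochs):
    for _ in range(num_epochs):
        for _ in range(engine.schedule.num_steps):  # whole accumulation windows only
            engine.step()
        engine.complete()                           # leftover batches + epoch-wise LR step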
11 changes: 6 additions & 5 deletions configs/vit/vit_2d.py
@@ -8,10 +8,10 @@
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
NUM_ATTENTION_HEADS = 2
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
DEPTH = 1

train_data = dict(
dataset=dict(
@@ -127,14 +127,14 @@
dict(type='LogMetricByEpochHook'),
dict(type='Accuracy2DHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]

parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
tensor=dict(size=1, mode='2d'),
)

# for fp16 training
@@ -146,7 +146,8 @@

lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
total_steps=60,
warmup_steps=5
)

# only needed when pipeline parallel is used
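Because the builder no longer converts warmup_epochs into steps, total_steps and warmup_steps must be supplied directly, and with gradient accumulation there is one optimizer step per accumulation window. A hedged helper for computing them by hand (names and numbers are illustrative):

def compute_total_steps(num_epochs, batches_per_epoch, gradient_accumulation=1):
    steps_per_epoch = batches_per_epoch // gradient_accumulation
    return num_epochs * steps_per_epoch

# e.g. 60 epochs of 100 batches with gradient_accumulation=5 -> 1200 optimizer steps
assert compute_total_steps(60, 100, gradient_accumulation=5) == 1200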
3 changes: 2 additions & 1 deletion examples/run_trainer.py
@@ -17,7 +17,8 @@ def run_trainer():
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
schedule=schedule,
gradient_accumulation=5,
)
logger.info("engine is built", ranks=[0])
