diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 2c64d8bd0e97..3e87864c8a37 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -12,8 +12,6 @@
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.data import DataLoader
 
 from .schedule import BaseSchedule, NoPipelineSchedule
 
@@ -21,46 +19,33 @@ class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
+    It controls an iteration in training.
 
-    :param train_dataloader: Dataloader in training
-    :param test_dataloader: Dataloader in evaluation
     :param model: The neural network model
-    :param criterion: Criterion for calculating loss
     :param optimizer: Optimizer for updating the parameters
-    :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
-    :param schedule: Running schedule in :meth:`step`
-    :type train_dataloader: DataLoader, optional
-    :type test_dataloader: DataLoader, optional
+    :param step_schedule: Running schedule in :meth:`step`
+    :param gradient_accumulation: Steps of gradient accumulation
+    :param gradient_clipping: Maximum norm for gradient clipping
     :type model: Module
-    :type criterion: _Loss, optional
-    :type optimizer: Optimizer, optional
-    :type lr_scheduler: _LRScheduler, optional
-    :type schedule: BaseSchedule, optional
+    :type optimizer: Optimizer
+    :type step_schedule: BaseSchedule, optional
+    :type gradient_accumulation: int, optional
+    :type gradient_clipping: float, optional
     """
+
     def __init__(self,
-                 train_dataloader: Optional[DataLoader] = None,
-                 test_dataloader: Optional[DataLoader] = None,
-                 model: Module = None,
-                 criterion: _Loss = None,
-                 optimizer: Optimizer = None,
-                 lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None,
+                 model: Module,
+                 optimizer: Optimizer,
+                 step_schedule: BaseSchedule = None,
                  gradient_accumulation: int = 1,
-                 lr_scheduler_step: str = 'epoch'):
-        self.train_dataloader = train_dataloader
-        self.test_dataloader = test_dataloader
-        assert model is not None, "Engine requires a model"
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.schedule = schedule if schedule is not None \
+                 gradient_clipping: float = 0.0):
+        self.schedule = step_schedule if step_schedule is not None \
             else NoPipelineSchedule()
+        self.schedule.initialize(model, optimizer)
         self.grad_accum_size = gradient_accumulation
-        self.grad_accum_step = 0
-        self.lr_step = 0  # for epoch updating
-        if lr_scheduler_step != 'epoch':
-            self.lr_step = 1
+        self.grad_accum_cur_step = 0
+        self.grad_clip = gradient_clipping
+        self.training = True  # default
         self._logger = get_global_dist_logger()
 
         # build gradient handler
@@ -72,8 +57,8 @@ def __init__(self,
                 f'argument gradient_handler_cfg expected type list, ' \
                 f'but got type {type(gpc.config.gradient_handler)}'
             gradient_handler_cfg = gpc.config.gradient_handler
-        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                         ZeroRedundancyOptimizer_Level_3)):
+        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                    ZeroRedundancyOptimizer_Level_3)):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             self._logger.info(
                 "Training with zero is detected, ZeROGradientHandler is automatically "
@@ -92,106 +77,71 @@ def __init__(self,
                 "to all-reduce the gradients after a training step.",
                 ranks=[0])
 
         for cfg in gradient_handler_cfg:
-            handler = build_gradient_handler(cfg, self.model, self.optimizer)
+            handler = build_gradient_handler(cfg, model, optimizer)
             self._gradient_handlers.append(handler)
 
-        self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer)
-        self.schedule.grad_accum = self.grad_accum_size
-        # add for robustness
-        if self.optimizer is None:
-            self.forward_only = True
-        else:
-            self.forward_only = False
-
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
-    def set_dataloader(self, data: DataLoader, train: bool = True):
-        """Sets dataloader in training or evaluation.
-
-        :param data: Dataloader to be set
-        :param train: Set training dataloader if True, otherwise evaluation dataloader
-        :type data: DataLoader
-        :type train: bool
-        """
-        if train:
-            self.train_dataloader = data
-        else:
-            self.test_dataloader = data
-
-    def get_model(self):
-        """Returns the neural network model in the engine.
-        """
-        return self.model
-
-    def get_optimizer(self):
-        """Returns optimizier in the engine.
-        """
-        return self.optimizer
-
-    def get_lr_scheduler(self):
-        """Returns the learning rate scheduler in the engine.
-        """
-        return self.lr_scheduler
-
     def train(self):
         """Sets the model to training mode.
         """
-        self.forward_only = False
-        self.schedule.train(dataloader=self.train_dataloader, mode=True)
+        self.training = True
 
     def eval(self):
         """Sets the model to evaluation mode.
        """
-        self.forward_only = True
-        self.schedule.train(dataloader=self.test_dataloader, mode=False)
-
-    def is_train(self):
-        """Returns True if it is in training, otherwise False.
-        """
-        return not self.forward_only
-
-    def get_lr(self):
-        """Gets current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
-    def step(self, return_loss=True):
+        self.training = False
+
+    def step(self,
+             data_iter,
+             model: Module,
+             criterion: _Loss,
+             optimizer: Optimizer = None,
+             is_last_iteration: bool = False,
+             return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
         evaluation over a batch of dataset.
 
+        :param data_iter: Data iterator of the dataset
+        :param model: The neural network model
+        :param criterion: Loss function used to calculate the loss
+        :param optimizer: Optimizer for updating the parameters
+        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
         :param return_loss: loss will be returned if True
-        :type return_loss: bool
+        :type data_iter: Iterator
+        :type model: Module
+        :type criterion: _Loss
+        :type optimizer: Optimizer, optional
+        :type is_last_iteration: bool, optional
+        :type return_loss: bool, optional
         :return: (output, lablel, loss)
         """
-        if not self.forward_only and self.grad_accum_step == 0:
-            self.schedule.zero_grad()
+        if self.training and self.grad_accum_cur_step == 0:
+            optimizer.zero_grad()
 
         output, label, loss = self.schedule.forward_backward_step(
-            forward_only=self.forward_only, return_loss=return_loss)
-
-        if not self.forward_only:
-            self.grad_accum_step += 1
-            if self.grad_accum_step == self.grad_accum_size:
+            data_iter, model, criterion, optimizer,
+            forward_only=not self.training,
+            grad_accum_size=self.grad_accum_size,
+            return_loss=return_loss)
+
+        if self.training:
+            self.grad_accum_cur_step += 1
+            if self.grad_accum_cur_step == self.grad_accum_size:
                 # all reduce gradients
                 self.handle_gradient()
-                self.schedule.step()
-                if self.lr_scheduler is not None and self.lr_step:
-                    self.lr_scheduler.step()
-                self.grad_accum_step = 0
+                self.schedule.optimizer_step(model, optimizer, self.grad_clip)
+                self.grad_accum_cur_step = 0
 
-        return output, label, loss
+        if is_last_iteration:
+            while True:
+                try:
+                    trash = next(data_iter)
+                except StopIteration:
+                    break
 
-    def complete(self):
-        """Updating after a epoch.
-        """
-        self.schedule.consume_batch()
-        if self.lr_scheduler is not None and self.lr_step == 0:
-            self.lr_scheduler.step()
+        return output, label, loss
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index c1454275ad63..1a1ef6e0416f 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -5,121 +5,87 @@
 import torch
 
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import get_current_device
 
 
 class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.
+    It mainly consists of forward_backward_step for the gradient backward pass and
+    optimizer_step for the parameter update.
+    For the convenience of enabling FP16, we aggregate all code that controls FP16
+    in the schedule classes.
     """
+
     def __init__(self):
-        self.initialized = False
         self.logger = get_global_dist_logger()
         self.grad_accum = 1
         self.training = False
 
-    @property
-    @abstractmethod
-    def num_steps(self):
-        """The number of batches in training or evaluation.
-        """
-        pass
-
-    def initialize(self,
-                   dataloader=None,
-                   model=None,
-                   criterion=None,
-                   optimizer=None):
-        """Initializes the schedule and set parameters before running.
+    @staticmethod
+    def _move_tensor(element):
+        if torch.is_tensor(element):
+            if not element.is_cuda:
+                return element.to(get_current_device()).detach()
+        return element
 
-        :param dataloader: DataLoader in training or evaluation
-        :param model: The neural network model
-        :param criterion: Criterion for calculating loss
-        :param optimizer: Optimizer for updating the parameters
-        """
-        self.dataloader = dataloader
-        assert model is not None, "Schedule requires a model"
-        self.model = model
-        assert criterion is not None, "Schedule requires a criterion"
-        self.criterion = criterion
-        assert optimizer is not None, "Schedule requires an optimizer"
-        self.optimizer = optimizer
-        self.initialized = True
-
-    def check_initialized(self):
-        """Checks whether the schedule is initialized.
-        """
-        assert self.initialized, \
-            'Schedule is not initialized. Call schedule.initialize(...) before using it.'
+    def _move_to_device(self, data):
+        if isinstance(data, (tuple, list)):
+            data = tuple([self._move_tensor(d) for d in data])
+        elif torch.is_tensor(data):
+            data = data.to(get_current_device()).detach()
+        return data
 
-    def load_batch(self):
-        """Loads a batch of dataset. It returns the data and labels which are
+    def load_batch(self, data_iter):
+        """Loads a batch from the data iterator. It returns the data and labels which are
         already in the same GPU as where the model's.
 
         :return: (data, label)
-        :rtype: (Tensor, Tensor) 
+        :rtype: (Tensor, Tensor)
         """
-        self.check_initialized()
-        if self.data_iter is None:
+        if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
-        data, label = next(self.data_iter)
+        data, label = next(data_iter)
         return self._move_to_device(data), self._move_to_device(label)
 
-    def consume_batch(self):
-        while True:
-            try:
-                self.load_batch()
-            except StopIteration:
-                break
+    def initialize(self, model, optimizer):
+        """Initializes the model and the optimizer before training.
+        This is often used in FP16 training.
 
-    def _move_to_device(self, data):
-        if isinstance(data, (
-                tuple,
-                list,
-        )):
-            data = tuple([
-                d.to(get_current_device()).detach() for d in data
-                if torch.is_tensor(d)
-            ])
-        elif torch.is_tensor(data):
-            data = data.to(get_current_device()).detach()
-        return data
-
-    def train(self, dataloader=None, mode=True):
-        """Sets the dataloader to be used and turn the model to
-        training or evaluation mode.
-
-        :param dataloader: Dataloader to be used
-        :param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
-        """
-        self.check_initialized()
-        self.training = mode
-        if mode:
-            self.model.train()
-        else:
-            self.model.eval()
-        if dataloader is not None:
-            self.dataloader = dataloader
-            self.data_iter = iter(dataloader)
-
-    def zero_grad(self, forward_only=False):
-        """Cleans gradients with the optimizer.
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
         """
-        if not forward_only:
-            self.check_initialized()
-            self.optimizer.zero_grad()
+        return model, optimizer
 
-    def step(self):
-        """Updates the parameters and learning rate with the optimizer.
+    @abstractmethod
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
+        """The process function over a batch of dataset for training or evaluation.
+
+        :param data_iter: Data iterator of the dataset
+        :param model: Model used in training or evaluation
+        :param optimizer: Optimizer used in training or evaluation
+        :param criterion: Loss function
+        :param forward_only: If True, the process won't include backward
+        :param grad_accum_size: Steps of gradient accumulation
+        :param return_loss: If False, the loss won't be returned
         """
-        self.check_initialized()
-        self.optimizer.step()
+        pass
 
     @abstractmethod
-    def forward_backward_step(self, forward_only=False, return_loss=True):
-        """The process function over a batch of dataset for training or evaluation.
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        """Updates the parameters with the optimizer.
 
-        :param forward_only: If True, the process won't include backward.
-        :param return_loss: If False, the loss won't be returned.
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
+        :param grad_clipping: Maximum norm for gradient clipping
+        :type grad_clipping: float, optional
         """
         pass
diff --git a/colossalai/engine/schedule/_no_pipeline.py b/colossalai/engine/schedule/_no_pipeline.py
index 899f03ab4079..7f62475c4d39 100644
--- a/colossalai/engine/schedule/_no_pipeline.py
+++ b/colossalai/engine/schedule/_no_pipeline.py
@@ -10,8 +10,6 @@
 except:
     print('PyTorch amp is not supported with the current PyTorch version')
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
 from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
@@ -75,35 +73,29 @@ def __init__(
         self.fp16 = False
         self.amp_type = None
 
-    @property
-    def num_steps(self):
-        length = len(self.dataloader)
-        if self.training:
-            length -= length % self.grad_accum
-        return length
-
-    def initialize(self,
-                   dataloader=None,
-                   model=None,
-                   criterion=None,
-                   optimizer=None):
-        super().initialize(dataloader,
-                           model,
-                           criterion,
-                           optimizer)
-        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                       ZeroRedundancyOptimizer_Level_3)):
+    def initialize(self, model, optimizer):
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                  ZeroRedundancyOptimizer_Level_3)):
             self.use_zero_level_2_3 = True
-            assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
+            assert self.amp_type != AMP_TYPE.PARALLEL, \
+                'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
 
         if self.fp16:
             if self.amp_type == AMP_TYPE.TORCH:
                 self._torch_amp_scaler = GradScaler(**self.amp_cfg)
             elif self.amp_type == AMP_TYPE.APEX:
-                self.model, self.optimizer = apex_amp.initialize(
-                    self.model, self.optimizer, **self.amp_cfg)
-
-    def forward_backward_step(self, forward_only=False, return_loss=True):
+                model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
+
+        return model, optimizer
+
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """The process function that loads loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
 
@@ -112,22 +104,22 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
 
-        data, label = self.load_batch()
+        data, label = self.load_batch(data_iter)
         loss = None
 
         # forward
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             with torch_amp.autocast():
-                output = self.model(*data)
+                output = model(*data)
                 if not isinstance(output, (tuple, list)):
                     output = (output,)
                 if return_loss:
-                    loss = self.criterion(*output, *label)
+                    loss = criterion(*output, *label)
         else:
             if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                 data = convert_to_fp16(data)
 
-            output = self.model(*data)
+            output = model(*data)
 
             if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                 output = convert_to_fp32(output)
@@ -135,44 +127,42 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
             if not isinstance(output, (tuple, list)):
                 output = (output,)
             if return_loss:
-                loss = self.criterion(*output, *label)
-        loss /= self.grad_accum
+                loss = criterion(*output, *label)
+
+        loss /= grad_accum_size
 
         if not forward_only:
             # backward
             if self.use_zero_level_2_3:
-                self.optimizer.backward(loss)
+                optimizer.backward(loss)
             elif self.fp16:
                 if self.amp_type == AMP_TYPE.APEX:
-                    with apex_amp.scale_loss(loss,
-                                             self.optimizer) as scaled_loss:
+                    with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
                         scaled_loss.backward()
                 elif self.amp_type == AMP_TYPE.TORCH:
                     self._torch_amp_scaler.scale(loss).backward()
                 elif self.amp_type == AMP_TYPE.PARALLEL:
-                    loss = self.optimizer.scale_loss(loss)
+                    loss = optimizer.scale_loss(loss)
                     loss.backward()
                     # scale back to display the original value in logs
-                    loss.div_(self.optimizer.grad_scaler.scale)
+                    loss.div_(optimizer.grad_scaler.scale)
             else:
                 loss.backward()
 
         if return_loss:
-            return output, label, loss * self.grad_accum
+            return output, label, loss * grad_accum_size
         else:
            return output, None, None
 
-    def step(self):
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
         # step optimizer
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
-            if getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
-                self._torch_amp_scaler.unscale_(self.optimizer)
-                clip_grad_norm_fp32(self.model.parameters(),
-                                    gpc.config.clip_grad)
-            self._torch_amp_scaler.step(self.optimizer)
+            if grad_clipping > 0.0:
+                self._torch_amp_scaler.unscale_(optimizer)
+                clip_grad_norm_fp32(model.parameters(), grad_clipping)
+            self._torch_amp_scaler.step(optimizer)
             self._torch_amp_scaler.update()
         else:
-            if not self.fp16 and not self.use_zero_level_2_3 and getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
-                clip_grad_norm_fp32(self.model.parameters(),
-                                    gpc.config.clip_grad)
-            self.optimizer.step()
+            if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
+                clip_grad_norm_fp32(model.parameters(), grad_clipping)
+            optimizer.step()
diff --git a/colossalai/engine/schedule/_pipeline.py b/colossalai/engine/schedule/_pipeline.py
index 4b625882de96..6defea93d57a 100644
--- a/colossalai/engine/schedule/_pipeline.py
+++ b/colossalai/engine/schedule/_pipeline.py
@@ -93,12 +93,11 @@ def _sync_data(self):
             )
 
     # Pipeline schedule just puts data in memory
-    def load_batch(self):
-        self.check_initialized()
-        if self.data_iter is None:
+    def load_batch(self, data_iter):
+        if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
         self.batch_pos = 0
-        data, label = next(self.data_iter)
+        data, label = next(data_iter)
         self.batch_data, self.batch_label = \
             self._move_to_device(data), self._move_to_device(label)
         batch_size = self.batch_data.shape[0]
@@ -117,24 +116,8 @@ def load_micro_batch(self):
         self.batch_pos += self.microbatch_size
         return (data,), (label,)
 
-    @property
-    def num_steps(self):
-        length = len(self.dataloader)
-        if self.training:
-            length -= length % self.grad_accum
-        return length
-
-    def initialize(self,
-                   dataloader=None,
-                   model=None,
-                   criterion=None,
-                   optimizer=None):
-        super().initialize(dataloader,
-                           model,
-                           criterion,
-                           optimizer)
-        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                       ZeroRedundancyOptimizer_Level_3)):
+    def initialize(self, model, optimizer):
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )
@@ -146,7 +129,8 @@ def initialize(self,
                 'default tensor dtype is set to torch.half for fp16 training',
                 ranks=[0])
 
-    def forward_step(self, input_tensor, return_tensors, return_loss=True):
+    def forward_step(self, model, criterion, input_tensor, return_tensors,
+                     grad_accum_size, return_loss=True):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -157,13 +141,14 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True):
         if self.amp_type == AMP_TYPE.PARALLEL:
             input_tensor = convert_to_fp16(input_tensor)
         input_tensor = squeeze(input_tensor)
-        output_tensor = self.model(input_tensor)
+        output_tensor = model(input_tensor)
         output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_loss:
                 input_tensor, label = self.load_micro_batch()
-                loss_reduced = self.criterion(output_tensor, *label) / (self.num_microbatches * self.grad_accum)
+                loss_reduced = criterion(output_tensor, *label) \
+                    / (self.num_microbatches * grad_accum_size)
                 return_tensors.append(
                     tuple((output_tensor, label[0], loss_reduced)))
                 return loss_reduced
@@ -174,7 +159,7 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True):
         else:
             return output_tensor
 
-    def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+    def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
@@ -187,7 +172,7 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
 
         # Backward pass.
         if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
-            output_tensor = self.optimizer.scale_loss(output_tensor)
+            output_tensor = optimizer.scale_loss(output_tensor)
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
 
         # Collect the grad of the input_tensor.
@@ -197,7 +182,14 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
 
         return input_tensor_grad
 
-    def forward_backward_step(self, forward_only=True, return_loss=True):
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
 
@@ -207,7 +199,7 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
 
-        self.load_batch()
+        self.load_batch(data_iter)
         num_warmup_microbatches = \
             (gpc.get_world_size(ParallelMode.PIPELINE) -
              gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@@ -233,9 +225,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = recv_tensor_meta(ft_shape)
             input_tensor = recv_forward(ft_shape)
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
                 fs_checker = send_tensor_meta(output_tensor, fs_checker)
@@ -257,9 +251,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))
 
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if forward_only:
                 send_forward(output_tensor)
@@ -279,9 +275,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 if last_iteration:
                     input_tensor = None
@@ -298,9 +296,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
 
                 output_tensor_grad = recv_backward(bt_shape)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 send_backward(input_tensor_grad)
 
@@ -309,8 +309,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True):
                 output, label, loss = tuple(map(list, zip(*return_tensors)))
                 return (torch.cat(output, dim=0),
                         torch.cat(label, dim=0),
-                        sum(loss) * self.grad_accum)
+                        sum(loss) * grad_accum_size)
             else:
                 return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:
             return tuple((None, None, None))
+
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        optimizer.step()
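
Usage sketch (illustrative only, not part of the patch): after this refactor, the Engine no longer owns the dataloader, criterion, or LR scheduler, so a trainer drives it roughly as below. It assumes a model, optimizer, criterion, and dataloader have already been built through the usual colossalai initialization; `train_dataloader` and `num_steps` are placeholder names.

    # Minimal sketch of driving the refactored Engine; names below are assumptions.
    engine = Engine(model=model,
                    optimizer=optimizer,
                    step_schedule=None,            # falls back to NoPipelineSchedule()
                    gradient_accumulation=4,
                    gradient_clipping=1.0)

    engine.train()                                 # now only toggles engine.training
    data_iter = iter(train_dataloader)             # step() consumes an iterator, not a DataLoader
    for i in range(num_steps):
        # model, criterion and optimizer are passed explicitly on every step
        output, label, loss = engine.step(data_iter, model, criterion, optimizer,
                                          is_last_iteration=(i == num_steps - 1))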