15 changes: 11 additions & 4 deletions colossalai/engine/gradient_accumulation/__init__.py
@@ -1,10 +1,17 @@
+from typing import Iterable, List
+
 import torch.nn as nn
-from typing import List
-from colossalai.engine import BaseGradientHandler
-from typing import Iterable
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
+
+from colossalai.engine import BaseGradientHandler
+
+from ._gradient_accumulation import (
+    GradAccumDataloader,
+    GradAccumGradientHandler,
+    GradAccumLrSchedulerByStep,
+    GradAccumOptimizer,
+)
 
 __all__ = [
     'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
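The names re-exported above (GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler, plus accumulate_gradient in __all__) wrap the standard gradient-accumulation pattern: run several micro-batches, let gradients add up in param.grad, and only step the optimizer once per accumulation window. A minimal plain-PyTorch sketch of that pattern (toy model and loop; this is not ColossalAI's wrapper API):

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 2)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 2)), batch_size=4)

accumulate_size = 4  # micro-batches per optimizer step

for step, (data, target) in enumerate(loader):
    # Scale the loss so the summed gradient equals that of one large batch.
    loss = criterion(model(data), target) / accumulate_size
    loss.backward()  # gradients accumulate in param.grad across iterations

    if (step + 1) % accumulate_size == 0:
        optimizer.step()       # apply the accumulated gradient
        optimizer.zero_grad()  # clear it for the next window
```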
colossalai/engine/gradient_handler/_base_gradient_handler.py
@@ -5,7 +5,7 @@
 
 
 class BaseGradientHandler(ABC):
-    """A basic helper class to handle all-reduce operations of gradients across different parallel groups 
+    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.
 
     Args:
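For orientation, the contract this base class defines is small: hold a model and an optimizer, and run some collective communication on the gradients when handle_gradient() is called. A self-contained sketch of that contract (class and method names below are illustrative, not ColossalAI's exact code; assumes torch.distributed has been initialized):

```python
from abc import ABC, abstractmethod

import torch.distributed as dist


class GradientHandlerSketch(ABC):
    """Minimal mirror of the gradient-handler contract (hypothetical)."""

    def __init__(self, model, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """Synchronize gradients across ranks before optimizer.step()."""


class NaiveAllReduceHandler(GradientHandlerSketch):
    """All-reduces every parameter gradient, one tensor at a time."""

    def handle_gradient(self):
        world_size = dist.get_world_size()
        for param in self._model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)  # sum across ranks
                param.grad.div_(world_size)  # turn the sum into a mean
```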
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class DataParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in 
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are 
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
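The bucketing the docstring mentions is why this handler imports bucket_allreduce: launching one collective per parameter is latency-bound, so gradients of the same element type are flattened into one buffer and reduced in a single call. A hypothetical stand-in for .utils.bucket_allreduce illustrating the idea (not the module's actual code):

```python
from collections import defaultdict

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucket_allreduce_sketch(params, group=None):
    # 1. Bucket gradients by dtype: fp16 and fp32 cannot share a flat buffer.
    buckets = defaultdict(list)
    for param in params:
        if param.requires_grad and param.grad is not None:
            buckets[param.grad.dtype].append(param.grad.data)

    # 2. One all-reduce per bucket instead of one per parameter.
    world_size = dist.get_world_size(group)
    for grads in buckets.values():
        flat = _flatten_dense_tensors(grads)
        dist.all_reduce(flat, group=group)
        flat.div_(world_size)  # average across the data parallel group
        # 3. Scatter the reduced values back into the original grad tensors.
        for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            grad.copy_(synced)
```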
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -4,19 +4,20 @@
 
 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from ._base_gradient_handler import BaseGradientHandler
 
 
 @GRADIENT_HANDLER.register_module
 class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in sub parallel groups.
-    A all-reduce collective communication will be operated in 
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among all sub pipeline parallel groups.
-    For better performance, it bucketizes the gradients of all parameters that are 
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
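Shared modules here are weights that live on more than one pipeline stage (tied input/output embeddings being the classic case); each stage computes only a partial gradient for the shared weight, so the partials must be combined across the sub-group that shares it. A minimal sketch of that step, assuming a hypothetical shared_params list and an already-created pipeline_group, summing rather than averaging since each rank holds a distinct contribution to the same logical weight:

```python
import torch.distributed as dist


def sync_shared_module_grads(shared_params, pipeline_group):
    for param in shared_params:
        if param.grad is not None:
            # Sum the partial gradients from every stage that owns the weight.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pipeline_group)
```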
colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class SequenceParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in 
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are 
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
@@ -1,4 +1,5 @@
 from colossalai.registry import GRADIENT_HANDLER
+
 from ._base_gradient_handler import BaseGradientHandler
 
 
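The @GRADIENT_HANDLER.register_module decorator seen throughout these files stores each handler class in a registry under its class name, so a config can select handlers by string. A minimal sketch of the decorator-based registry pattern (hypothetical Registry class, not colossalai.registry's real implementation):

```python
class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self, cls):
        # Used as a bare decorator: record the class under its own
        # name and return it unchanged.
        self._modules[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._modules[name]


GRADIENT_HANDLER = Registry()


@GRADIENT_HANDLER.register_module
class MyGradientHandler:
    def handle_gradient(self):
        pass


# A config can now name the handler as a string:
handler_cls = GRADIENT_HANDLER.get('MyGradientHandler')
handler = handler_cls()
```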
2 changes: 1 addition & 1 deletion colossalai/engine/schedule/__init__.py
@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
+from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
 
 __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
2 changes: 1 addition & 1 deletion colossalai/engine/schedule/_base_schedule.py
@@ -2,10 +2,10 @@
 # -*- encoding: utf-8 -*-
 
 from abc import ABC, abstractmethod
+from typing import Callable, Iterable
 
 import torch
 
-from typing import Iterable, Callable
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 
9 changes: 5 additions & 4 deletions colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Iterable
+import inspect
+from typing import Callable, Iterable
+
 import torch
-import inspect
-from ._base_schedule import BaseSchedule
 
 from colossalai.utils import conditional_context
-from typing import Callable
+
+from ._base_schedule import BaseSchedule
 
 
 class NonPipelineSchedule(BaseSchedule):