29 changes: 14 additions & 15 deletions colossalai/initialize.py
@@ -15,26 +15,25 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT

from colossalai.logging import get_dist_logger

from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.gemini.ophooks import BaseOpHook

from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param

from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.gradient_accumulation import accumulate_gradient

from colossalai.engine.schedule import (
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
get_tensor_shape,
)
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer

from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2

@@ -301,9 +300,9 @@ def initialize(model: nn.Module,
model = model().to(get_current_device())

# optimizer maybe a optimizer_cls
logger.warning("Initializing an non ZeRO model with optimizer class")
if isinstance(optimizer, Callable):
optimizer = optimizer(model.parameters())
logger.warning("Initializing an non ZeRO model with optimizer class")

if not use_zero:
if is_using_sequence():
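For context, a minimal sketch of the control flow that results from the second hunk: the warning is now emitted only on the path where an optimizer class or factory callable is passed (and therefore instantiated here), not when an already-built Optimizer instance is supplied. The helper name _build_optimizer is hypothetical and introduced only for illustration; the isinstance check and the warning text are taken from the diff.

    from typing import Callable

    from torch import nn


    def _build_optimizer(optimizer, model: nn.Module, logger):
        # `optimizer` may be a ready-made torch Optimizer instance or an
        # optimizer class / factory callable, e.g. `lambda p: torch.optim.Adam(p)`.
        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())
            # The warning fires only on this instantiation path, matching the
            # placement after this change.
            logger.warning("Initializing an non ZeRO model with optimizer class")
        return optimizer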