Skip to content
2 changes: 1 addition & 1 deletion colossalai/engine/_base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler
Expand Down
3 changes: 2 additions & 1 deletion colossalai/engine/schedule/_pipeline_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

from ._base_schedule import BaseSchedule

Expand Down Expand Up @@ -157,6 +156,7 @@ def load_micro_batch(self):
return self._move_to_device(mciro_batch_data)

def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):
Expand Down Expand Up @@ -482,6 +482,7 @@ def __init__(self,
self.num_model_chunks = num_model_chunks

def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
if isinstance(engine.model, ShardedModelV2):
self.dtype = torch.half
elif isinstance(engine.model[0], NaiveAMPModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
Expand Down
2 changes: 1 addition & 1 deletion colossalai/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook

from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param
Expand Down
2 changes: 1 addition & 1 deletion colossalai/utils/profiler/legacy/mem_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union
from colossalai.engine import Engine
from torch.utils.tensorboard import SummaryWriter
from colossalai.engine.ophooks import MemTracerOpHook
from colossalai.gemini.ophooks import MemTracerOpHook
from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import List
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension

Expand Down
4 changes: 2 additions & 2 deletions colossalai/zero/sharded_model/sharded_model_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.utils import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
Expand Down
2 changes: 1 addition & 1 deletion colossalai/zero/utils/zero_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from colossalai.utils import get_current_device

from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook

from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector
Expand Down
6 changes: 3 additions & 3 deletions docs/colossalai/colossalai.engine.ophooks.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
colossalai.engine.ophooks
colossalai.gemini.ophooks
=========================

.. automodule:: colossalai.engine.ophooks
.. automodule:: colossalai.gemini.ophooks
:members:


.. toctree::
:maxdepth: 2

colossalai.engine.ophooks.zero_hook
colossalai.gemini.ophooks.zero_hook
4 changes: 2 additions & 2 deletions docs/colossalai/colossalai.engine.ophooks.zero_hook.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
colossalai.engine.ophooks.zero\_hook
colossalai.gemini.ophooks.zero\_hook
====================================

.. automodule:: colossalai.engine.ophooks.zero_hook
.. automodule:: colossalai.gemini.ophooks.zero_hook
:members:
2 changes: 1 addition & 1 deletion docs/colossalai/colossalai.engine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ colossalai.engine
:maxdepth: 2

colossalai.engine.gradient_handler
colossalai.engine.ophooks
colossalai.gemini.ophooks
colossalai.engine.schedule
86 changes: 0 additions & 86 deletions tests/test_engine/test_param_hook.py

This file was deleted.