diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py index cfcfb385667c..0e575254c0b6 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/gemini/tensor_placement_policy.py @@ -1,15 +1,15 @@ +import functools from abc import ABC, abstractmethod from time import time -from typing import List, Optional +from typing import List, Optional, Type + import torch -from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.gemini.memory_tracer import MemStatsCollector -from typing import Type -import functools +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity class TensorPlacementPolicy(ABC):