2 changes: 1 addition & 1 deletion colossalai/accelerator/cuda_accelerator.py
@@ -279,4 +279,4 @@ def autocast(
"""
Return autocast function
"""
- return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+ return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
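
PyTorch has deprecated the `torch.cuda.amp` entry points in favor of the device-agnostic `torch.amp` namespace, which is what this change adopts. A minimal sketch of the new-style call in a training step (the toy model and data below are illustrative, not from this repo):

```python
import torch

model = torch.nn.Linear(16, 4).cuda()      # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# torch.amp.GradScaler takes the device as an argument (torch >= 2.3);
# older releases use torch.cuda.amp.GradScaler instead.
scaler = torch.amp.GradScaler("cuda")

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

# device_type is now an explicit argument instead of being implied
# by the torch.cuda.amp namespace.
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```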
2 changes: 1 addition & 1 deletion colossalai/kernel/jit/option.py
@@ -1,7 +1,6 @@
import torch

from colossalai.accelerator import get_accelerator
- from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
+ from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
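
Moving the `colossalai.legacy` import out of module scope and into `warmup_jit_fusion` defers it until the function is actually called, so merely importing `colossalai.kernel.jit` no longer pulls in the legacy layer stack. This is a common way to break circular imports and trim import time. A minimal sketch of the pattern, with hypothetical names:

```python
def warmup(steps: int = 10) -> None:
    # Deferred import: `heavy_legacy_pkg` is a hypothetical module that is
    # expensive to import or would create an import cycle at module scope.
    from heavy_legacy_pkg import ExpensiveHelper

    helper = ExpensiveHelper()
    for _ in range(steps):
        helper.run_once()  # hypothetical warm-up work
```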
10 changes: 8 additions & 2 deletions colossalai/pipeline/schedule/_utils.py
@@ -3,8 +3,9 @@

import torch
import torch.cuda
+ from packaging.version import Version
from torch.nn import Module
- from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+ from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten


# this registration is for torch versions <= 1.13.1 and may be removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))


- _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+ if Version(torch.__version__) <= Version("1.13.1"):
+     try:
+         from torch.utils._pytree import register_pytree_node as _register_pytree_node
+     except ImportError:
+         from torch.utils._pytree import _register_pytree_node
+     _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


def tree_map_hf(fn: Any, pytree: Any):
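
The version gate matters because newer PyTorch releases register `OrderedDict` as a pytree node themselves and renamed the private `_register_pytree_node` helper to `register_pytree_node`, so unconditionally re-registering would be redundant or raise. A small sketch of what the registration enables, using standard `torch.utils._pytree` functions (note this module is private and its API may shift between releases):

```python
from collections import OrderedDict

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# On recent torch, OrderedDict is already a registered pytree node, so it
# flattens into leaves plus a spec and reconstructs without manual setup.
od = OrderedDict(a=torch.ones(2), b=torch.zeros(3))
leaves, spec = tree_flatten(od)          # leaves: [tensor a, tensor b]
rebuilt = tree_unflatten(leaves, spec)   # same structure, same key order
assert isinstance(rebuilt, OrderedDict) and list(rebuilt) == ["a", "b"]
```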
11 changes: 6 additions & 5 deletions colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,10 +1,5 @@
import torch.nn

- from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-     GradMemStats,
-     GradMemTracerHook,
-     ParamMemTracerHook,
- )
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float

@@ -27,6 +22,12 @@ class RuntimeMemTracer:

def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
+ from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+     GradMemStats,
+     GradMemTracerHook,
+     ParamMemTracerHook,
+ )
+
self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()
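
The same lazy-import treatment applies here, but inside `__init__`: the legacy hook classes are resolved the first time a `RuntimeMemTracer` is constructed rather than when the `colossalai.zero` package is imported. A hedged usage sketch (the import path is inferred from the file path in this diff, and only the constructor shown above is exercised):

```python
import torch
import torch.nn as nn

from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Constructing the tracer is now the first point at which the
# colossalai.legacy hook classes are imported.
tracer = RuntimeMemTracer(model, dtype=torch.half)
```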
3 changes: 2 additions & 1 deletion colossalai/zero/gemini/placement_policy.py
@@ -8,7 +8,6 @@
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
- from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk

from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ def evict_tensors(
Returns:
int: the volume of memory that is evicted
"""
+ from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
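
`evict_tensors` compares device capacity (via the lazily imported `colo_device_memory_capacity`) with the chunk manager's resident model data to size the eviction. A simplified sketch of that arithmetic; the function name, the `margin` heuristic, and the numbers are illustrative, not the actual policy:

```python
def required_eviction_volume(cuda_capacity: int, used_model_data: int,
                             incoming_data: int, margin: float = 0.1) -> int:
    """Bytes that must be evicted so incoming_data fits under a safety margin.

    Illustrative only: the real policy in placement_policy.py also accounts
    for non-model memory and other inputs.
    """
    available = cuda_capacity * (1.0 - margin) - used_model_data
    return max(0, int(incoming_data - available))

# e.g. a 40 GiB device with 35 GiB of chunks resident and 4 GiB incoming
print(required_eviction_volume(40 << 30, 35 << 30, 4 << 30))  # bytes to evict
```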
@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:

- 1. torch.cuda.amp
+ 1. torch.amp
2. apex.amp
3. naive amp

@@ -16,7 +16,7 @@
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:

- 1. torch.cuda.amp
+ 1. torch.amp
2. apex.amp
3. naive amp
