Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

Expand All @@ -37,6 +38,7 @@
DP_AXIS = 0
TP_AXIS = 1


def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
Expand All @@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
start_index += len(group["params"])

return param_info


class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -359,6 +363,8 @@ def __init__(
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if IS_NPU_AVAILABLE:
assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
Expand Down Expand Up @@ -432,7 +438,7 @@ def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]

def configure(
self,
Expand Down Expand Up @@ -480,4 +486,4 @@ def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
raise NotImplementedError
198 changes: 139 additions & 59 deletions colossalai/utils/device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -44,84 +44,164 @@ def get_current_device() -> torch.device:
return torch.device("cpu")


def synchronize():
"""Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
"""
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
torch.cuda.synchronize()
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
torch.npu.synchronize()
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")


# device semantics


def can_device_access_peer(device, peer_device) -> bool:
    """Return True if ``device`` can directly access the memory of ``peer_device``.

    Dispatched via ``_dispatch_device_func`` to the ``torch.cuda`` or
    ``torch.npu`` function of the same name, depending on availability.
    """
    return _dispatch_device_func("can_device_access_peer", device, peer_device)


def current_device() -> int:
    """Return the index of the currently selected device (CUDA or NPU)."""
    return _dispatch_device_func("current_device")


def current_stream(device=None):
    """Return the currently selected stream for ``device`` (None = current device).

    Dispatched to ``torch.cuda.current_stream`` or ``torch.npu.current_stream``.
    """
    return _dispatch_device_func("current_stream", device)


def default_stream(device=None):
    """Return the default stream for ``device`` (None = current device)."""
    return _dispatch_device_func("default_stream", device)


def device_count() -> int:
if torch.cuda.is_available():
return torch.cuda.device_count()
elif IS_NPU_AVAILABLE:
return torch.npu.device_count()
else:
raise RuntimeError("No device available")
return _dispatch_device_func("device_count")


def empty_cache():
"""Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif IS_NPU_AVAILABLE:
torch.npu.empty_cache()
def get_device_capability(device=None) -> Tuple[int, int]:
    """Return the (major, minor) compute capability of ``device``.

    NOTE(review): dispatched to ``torch.npu`` on Ascend — confirm the NPU
    backend exposes ``get_device_capability`` with the same return shape.
    """
    return _dispatch_device_func("get_device_capability", device)


def get_device_name(device=None) -> str:
    """Return the name of ``device`` (None = current device)."""
    return _dispatch_device_func("get_device_name", device)


def get_device_properties(device):
    """Return the properties object of ``device`` (backend-specific type)."""
    return _dispatch_device_func("get_device_properties", device)


def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
if torch.cuda.is_available():
torch.cuda.set_device(index)
elif IS_NPU_AVAILABLE:
torch.npu.set_device(index)
_dispatch_device_func("set_device", index)


def Stream(device=None, priority=0, **kwargs):
if torch.cuda.is_available():
return torch.cuda.Stream(device, priority, **kwargs)
elif IS_NPU_AVAILABLE:
return torch.npu.Stream(device, priority, **kwargs)
else:
raise RuntimeError("No device available")
def set_stream(stream_):
    """Set the current stream; thin wrapper over the backend's ``set_stream``."""
    return _dispatch_device_func("set_stream", stream_)


def stream(stream_):
if torch.cuda.is_available():
return torch.cuda.stream(stream_)
elif IS_NPU_AVAILABLE:
return torch.npu.stream(stream_)
else:
raise RuntimeError("No device available")
return _dispatch_device_func("stream", stream_)


def set_stream(stream_):
if torch.cuda.is_available():
return torch.cuda.set_stream(stream_)
elif IS_NPU_AVAILABLE:
return torch.npu.set_stream(stream_)
else:
raise RuntimeError("No device available")
def synchronize():
    """Block until all kernels on the current device (CUDA or NPU) complete."""
    return _dispatch_device_func("synchronize")


def current_stream(device=None):
if torch.cuda.is_available():
return torch.cuda.current_stream(device)
elif IS_NPU_AVAILABLE:
return torch.npu.current_stream(device)
else:
raise RuntimeError("No device available")
def utilization(device=None) -> int:
    """Return the device utilization percentage over the sampling period.

    NOTE(review): ``torch.cuda.utilization`` requires pynvml; verify the NPU
    backend provides an equivalent before relying on this on Ascend.
    """
    return _dispatch_device_func("utilization", device)


def default_stream(device=None):
if torch.cuda.is_available():
return torch.cuda.default_stream(device)
elif IS_NPU_AVAILABLE:
return torch.npu.default_stream(device)
else:
raise RuntimeError("No device available")
# random number generator


def get_rng_state(device="cuda") -> torch.Tensor:
    """Return the RNG state of ``device`` as a ByteTensor.

    NOTE(review): the default ``device="cuda"`` is forwarded verbatim even when
    dispatch resolves to ``torch.npu`` — confirm NPU accepts a "cuda" device
    string or pass an explicit device on NPU.
    """
    return _dispatch_device_func("get_rng_state", device)


def get_rng_state_all() -> List[torch.Tensor]:
    """Return a list with the RNG state of every visible device."""
    return _dispatch_device_func("get_rng_state_all")


def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
    """Set the RNG state of ``device`` from ``new_state``.

    NOTE(review): as with ``get_rng_state``, the ``device="cuda"`` default is
    forwarded unchanged to the NPU backend — verify on Ascend.
    """
    return _dispatch_device_func("set_rng_state", new_state, device)


def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
    """Set the RNG state of every visible device from ``new_states``."""
    return _dispatch_device_func("set_rng_state_all", new_states)


def manual_seed(seed: int) -> None:
    """Seed the RNG of the current device (CUDA or NPU)."""
    return _dispatch_device_func("manual_seed", seed)


def manual_seed_all(seed: int) -> None:
    """Seed the RNGs of all visible devices."""
    return _dispatch_device_func("manual_seed_all", seed)


def seed() -> None:
    """Seed the current device's RNG with a random number."""
    return _dispatch_device_func("seed")


def seed_all() -> None:
    """Seed all visible devices' RNGs with random numbers."""
    return _dispatch_device_func("seed_all")


def initial_seed() -> int:
    """Return the current device RNG's initial seed."""
    return _dispatch_device_func("initial_seed")


# streams and events


def Stream(device=None, priority=0, **kwargs):
    """Construct a backend stream object (``torch.cuda.Stream`` / ``torch.npu.Stream``).

    Capitalized to mirror the class name it proxies; extra ``kwargs`` are
    forwarded to the backend constructor.
    """
    return _dispatch_device_func("Stream", device, priority, **kwargs)


def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
    """Construct a backend event object (``torch.cuda.Event`` / ``torch.npu.Event``).

    Arguments are forwarded positionally in the order the backend constructor
    declares them.
    """
    return _dispatch_device_func("Event", enable_timing, blocking, interprocess)


# memory management


def empty_cache() -> None:
    """Release all unoccupied cached memory held by the caching allocator."""
    return _dispatch_device_func("empty_cache")


def memory_stats(device=None) -> Dict[str, Any]:
    """Return the backend's memory-allocator statistics dictionary for ``device``."""
    return _dispatch_device_func("memory_stats", device)


def memory_summary(device=None, abbreviated=False) -> str:
    """Return a human-readable memory-allocator summary for ``device``."""
    return _dispatch_device_func("memory_summary", device, abbreviated)


def memory_snapshot():
    """Return a snapshot of the memory-allocator state across all devices."""
    return _dispatch_device_func("memory_snapshot")


def memory_allocated(device=None) -> int:
    """Return the current device memory occupied by tensors, in bytes."""
    return _dispatch_device_func("memory_allocated", device)


def max_memory_allocated(device=None) -> int:
    """Return the peak device memory occupied by tensors, in bytes."""
    return _dispatch_device_func("max_memory_allocated", device)


def reset_max_memory_allocated(device=None) -> None:
    """Reset the peak-allocated-memory counter tracked for ``device``."""
    return _dispatch_device_func("reset_max_memory_allocated", device)


def memory_reserved(device=None) -> int:
    """Return the current device memory reserved by the caching allocator, in bytes."""
    return _dispatch_device_func("memory_reserved", device)


def max_memory_reserved(device=None) -> int:
    """Return the peak device memory reserved by the caching allocator, in bytes."""
    return _dispatch_device_func("max_memory_reserved", device)


def set_per_process_memory_fraction(fraction: float, device=None) -> None:
    """Cap this process's allocatable device memory to ``fraction`` of the total."""
    return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)


def reset_peak_memory_stats(device=None) -> None:
    """Reset all peak counters in the memory-allocator statistics for ``device``."""
    return _dispatch_device_func("reset_peak_memory_stats", device)
15 changes: 7 additions & 8 deletions colossalai/zero/gemini/chunk/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.distributed import ProcessGroup

from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE


class TensorState(Enum):
Expand Down Expand Up @@ -169,7 +170,7 @@ def memory_usage(self) -> Dict[str, int]:

if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == "cuda":
if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
Expand All @@ -188,10 +189,8 @@ def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return "cuda"
elif self.cuda_shard is not None:
return "cuda"
if self.is_gathered or self.cuda_shard is not None:
return "npu" if IS_NPU_AVAILABLE else "cuda"
else:
return "cpu"

Expand Down Expand Up @@ -326,12 +325,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == "cuda", "each chunk should first be moved to CUDA"
assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return

if device.type == "cuda":
if device.type == "cuda" or device.type == "npu":
assert device == get_current_device(), "can't move chunk to another device"

if self.cuda_shard:
Expand Down Expand Up @@ -475,7 +474,7 @@ def optim_update(self) -> None:
assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
Expand Down
5 changes: 4 additions & 1 deletion colossalai/zero/gemini/chunk/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
device_type = tensor.device.type
if device_type == "npu":
device_type = "cuda"
self.total_mem[device_type] += tensor.numel() * tensor.element_size()

def __repr__(self) -> str:
msg = [
Expand Down
Loading