From 7cd753c9b6fcd79882e1551d4739a521401916d8 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Nov 2023 18:33:03 +0800 Subject: [PATCH 1/3] [npu] refactor device utils --- colossalai/utils/device.py | 198 ++++++++++++++++++++++++++----------- 1 file changed, 139 insertions(+), 59 deletions(-) diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py index 9b78f881b3bc..e1bd20d59dac 100644 --- a/colossalai/utils/device.py +++ b/colossalai/utils/device.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist @@ -44,84 +44,164 @@ def get_current_device() -> torch.device: return torch.device("cpu") -def synchronize(): - """Similar to cuda.synchronize(). - Waits for all kernels in all streams on a CUDA device to complete. - """ +def _dispatch_device_func(fn_name: str, *args, **kwargs): if torch.cuda.is_available(): - torch.cuda.synchronize() + return getattr(torch.cuda, fn_name)(*args, **kwargs) elif IS_NPU_AVAILABLE: - torch.npu.synchronize() + return getattr(torch.npu, fn_name)(*args, **kwargs) + else: + raise RuntimeError("No device available") + + +# device semantics + + +def can_device_access_peer(device, peer_device) -> bool: + return _dispatch_device_func("can_device_access_peer", device, peer_device) + + +def current_device() -> int: + return _dispatch_device_func("current_device") + + +def current_stream(device=None): + return _dispatch_device_func("current_stream", device) + + +def default_stream(device=None): + return _dispatch_device_func("default_stream", device) def device_count() -> int: - if torch.cuda.is_available(): - return torch.cuda.device_count() - elif IS_NPU_AVAILABLE: - return torch.npu.device_count() - else: - raise RuntimeError("No device available") + return _dispatch_device_func("device_count") -def empty_cache(): - """Similar to cuda.empty_cache() - Releases all unoccupied cached memory currently held by the caching allocator. 
- """ - if torch.cuda.is_available(): - torch.cuda.empty_cache() - elif IS_NPU_AVAILABLE: - torch.npu.empty_cache() +def get_device_capability(device=None) -> Tuple[int, int]: + return _dispatch_device_func("get_device_capability", device) + + +def get_device_name(device=None) -> str: + return _dispatch_device_func("get_device_name", device) + + +def get_device_properties(device): + return _dispatch_device_func("get_device_properties", device) def set_device(index: Optional[int] = None) -> None: if index is None: index = dist.get_rank() % device_count() - if torch.cuda.is_available(): - torch.cuda.set_device(index) - elif IS_NPU_AVAILABLE: - torch.npu.set_device(index) + _dispatch_device_func("set_device", index) -def Stream(device=None, priority=0, **kwargs): - if torch.cuda.is_available(): - return torch.cuda.Stream(device, priority, **kwargs) - elif IS_NPU_AVAILABLE: - return torch.npu.Stream(device, priority, **kwargs) - else: - raise RuntimeError("No device available") +def set_stream(stream_): + return _dispatch_device_func("set_stream", stream_) def stream(stream_): - if torch.cuda.is_available(): - return torch.cuda.stream(stream_) - elif IS_NPU_AVAILABLE: - return torch.npu.stream(stream_) - else: - raise RuntimeError("No device available") + return _dispatch_device_func("stream", stream_) -def set_stream(stream_): - if torch.cuda.is_available(): - return torch.cuda.set_stream(stream_) - elif IS_NPU_AVAILABLE: - return torch.npu.set_stream(stream_) - else: - raise RuntimeError("No device available") +def synchronize(): + return _dispatch_device_func("synchronize") -def current_stream(device=None): - if torch.cuda.is_available(): - return torch.cuda.current_stream(device) - elif IS_NPU_AVAILABLE: - return torch.npu.current_stream(device) - else: - raise RuntimeError("No device available") +def utilization(device=None) -> int: + return _dispatch_device_func("utilization", device) -def default_stream(device=None): - if torch.cuda.is_available(): - return torch.cuda.default_stream(device) - elif IS_NPU_AVAILABLE: - return torch.npu.default_stream(device) - else: - raise RuntimeError("No device available") +# random number generator + + +def get_rng_state(device="cuda") -> torch.Tensor: + return _dispatch_device_func("get_rng_state", device) + + +def get_rng_state_all() -> List[torch.Tensor]: + return _dispatch_device_func("get_rng_state_all") + + +def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: + return _dispatch_device_func("set_rng_state", new_state, device) + + +def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: + return _dispatch_device_func("set_rng_state_all", new_states) + + +def manual_seed(seed: int) -> None: + return _dispatch_device_func("manual_seed", seed) + + +def manual_seed_all(seed: int) -> None: + return _dispatch_device_func("manual_seed_all", seed) + + +def seed() -> None: + return _dispatch_device_func("seed") + + +def seed_all() -> None: + return _dispatch_device_func("seed_all") + + +def initial_seed() -> int: + return _dispatch_device_func("initial_seed") + + +# streams and events + + +def Stream(device=None, priority=0, **kwargs): + return _dispatch_device_func("Stream", device, priority, **kwargs) + + +def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + return _dispatch_device_func("Event", enable_timing, blocking, interprocess) + + +# memory management + + +def empty_cache() -> None: + return _dispatch_device_func("empty_cache") + + +def memory_stats(device=None) -> Dict[str, 
Any]: + return _dispatch_device_func("memory_stats", device) + + +def memory_summary(device=None, abbreviated=False) -> str: + return _dispatch_device_func("memory_summary", device, abbreviated) + + +def memory_snapshot(): + return _dispatch_device_func("memory_snapshot") + + +def memory_allocated(device=None) -> int: + return _dispatch_device_func("memory_allocated", device) + + +def max_memory_allocated(device=None) -> int: + return _dispatch_device_func("max_memory_allocated", device) + + +def reset_max_memory_allocated(device=None) -> None: + return _dispatch_device_func("reset_max_memory_allocated", device) + + +def memory_reserved(device=None) -> int: + return _dispatch_device_func("memory_reserved", device) + + +def max_memory_reserved(device=None) -> int: + return _dispatch_device_func("max_memory_reserved", device) + + +def set_per_process_memory_fraction(fraction: float, device=None) -> None: + return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) + + +def reset_peak_memory_stats(device=None) -> None: + return _dispatch_device_func("reset_peak_memory_stats", device) From 4b222ac238dc12dcf53259e33fb84288514dc8db Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Nov 2023 18:35:11 +0800 Subject: [PATCH 2/3] [gemini] support npu --- colossalai/booster/plugin/gemini_plugin.py | 10 +++- colossalai/zero/gemini/chunk/chunk.py | 15 +++-- colossalai/zero/gemini/chunk/manager.py | 5 +- colossalai/zero/gemini/gemini_ddp.py | 66 ++++++++++++++-------- colossalai/zero/gemini/gemini_mgr.py | 6 +- colossalai/zero/gemini/gemini_optimizer.py | 53 +++++++++-------- 6 files changed, 90 insertions(+), 65 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 9c7dc6836c1e..ceace61f00cd 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -24,6 +24,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device +from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -37,6 +38,7 @@ DP_AXIS = 0 TP_AXIS = 1 + def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. 
@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer): start_index += len(group["params"]) return param_info + + class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -359,6 +363,8 @@ def __init__( ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" + if IS_NPU_AVAILABLE: + assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), @@ -432,7 +438,7 @@ def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ["cuda"] + return ["cuda", "npu"] def configure( self, @@ -480,4 +486,4 @@ def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 4ea6cc662025..d32b2349a95c 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -7,6 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.utils import get_current_device +from colossalai.utils.device import IS_NPU_AVAILABLE class TensorState(Enum): @@ -169,7 +170,7 @@ def memory_usage(self) -> Dict[str, int]: if self.chunk_temp is not None: # this chunk is not closed - if self.chunk_temp.device.type == "cuda": + if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu": cuda_memory += self.chunk_mem else: cpu_memory += self.chunk_mem @@ -188,10 +189,8 @@ def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type else: - if self.is_gathered: - return "cuda" - elif self.cuda_shard is not None: - return "cuda" + if self.is_gathered or self.cuda_shard is not None: + return "npu" if IS_NPU_AVAILABLE else "cuda" else: return "cpu" @@ -326,12 +325,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False): # when the current chunk is not synchronized with the optimizer # just use another way for the movement if not self.optim_sync_flag: - assert device.type == "cuda", "each chunk should first be moved to CUDA" + assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA" self.__paired_shard_move() self.optim_sync_flag = True return - if device.type == "cuda": + if device.type == "cuda" or device.type == "npu": assert device == get_current_device(), "can't move chunk to another device" if self.cuda_shard: @@ -475,7 +474,7 @@ def optim_update(self) -> None: assert friend_chunk.is_gathered is True self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.optim_sync_flag = True - elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": + elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"): self.cuda_shard.copy_(friend_chunk.cuda_shard) self.optim_sync_flag = True self.cpu_vis_flag = False diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index d3c512fe978d..64a9166ade1d 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -204,7 +204,10 @@ def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. 
""" assert tensor not in self.tensor_chunk_map - self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + device_type = tensor.device.type + if device_type == "npu": + device_type = "cuda" + self.total_mem[device_type] += tensor.numel() * tensor.element_size() def __repr__(self) -> str: msg = [ diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index ade0a4909902..31ffde84bb1c 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,32 +10,30 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored -from colossalai.checkpoint_io.utils import gather_distributed_param - -from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager -from .gemini_hook import GeminiZeROHook -from .gemini_mgr import GeminiManager -from .memory_tracer import MemStats, OrderedParamGenerator -from .utils import get_temp_total_chunk_on_cuda - from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, - init_tensor_as_customization_distributed, get_device_mesh, + get_global_shape, get_sharding_spec, + init_as_dtensor, + init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, - get_global_shape, - init_as_dtensor ) +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager +from .gemini_hook import GeminiZeROHook +from .gemini_mgr import GeminiManager +from .memory_tracer import MemStats, OrderedParamGenerator +from .utils import get_temp_total_chunk_on_cuda try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -160,7 +158,7 @@ def __init__( self._init_chunks( param_order=param_order, strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != "cuda", + cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0), pin_memory=pin_memory, ) super().__init__(module) @@ -447,12 +445,13 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: global_shape = get_global_shape(tensor) device_mesh = get_device_mesh(tensor) shard_spec = get_sharding_spec(tensor) - record_tensor = init_as_dtensor(record_tensor, - device_mesh=device_mesh, - sharding_spec=shard_spec, - global_shape = global_shape) + record_tensor = init_as_dtensor( + record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape + ) elif is_customized_distributed_tensor(tensor): - init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) + init_tensor_as_customization_distributed( + record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn + ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() assert tensor not in chunk_to_save_data @@ 
-628,7 +627,15 @@ def _load_from_state_dict( local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None): + def load( + param_name, + dest_tensor, + copy_func, + source_device_mesh=None, + source_sharding_spec=None, + shard_fn=None, + gather_fn=None, + ): state_key = prefix + param_name if state_key in state_dict: input_param = state_dict[state_key] @@ -636,7 +643,9 @@ def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sha if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) elif shard_fn is not None and gather_fn is not None: - input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn) + input_param = distribute_tensor_with_customization( + input_param, shard_fn=shard_fn, gather_fn=gather_fn + ) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: @@ -681,7 +690,6 @@ def load_parameter(chunk_slice, data): temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): - source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None if is_distributed_tensor(tensor): # shard the input param @@ -693,7 +701,15 @@ def load_parameter(chunk_slice, data): parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) + load( + parameter_name, + tensor, + partial(load_parameter, parameter_slice), + source_device_mesh, + source_sharding_spec, + shard_fn, + gather_fn, + ) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -791,7 +807,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.cuda() + buffer.data = buffer.to(get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index f7ff3f6cdd86..150932e3d8d9 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -17,9 +17,7 @@ class GeminiManager: https://arxiv.org/abs/2108.05818 Args: - placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. - If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. - If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'. If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. Note that 'auto' policy can only work well when no other processes use CUDA during your training. chunk_manager (ChunkManager): A ``ChunkManager`` instance. 
@@ -121,7 +119,7 @@ def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == "cuda": + if chunk.device_type == "cuda" or chunk.device_type == "npu": if chunk.is_gathered: pass else: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index e20d846f1071..50d4f51d3390 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -7,31 +7,29 @@ import torch import torch.distributed as dist from packaging.version import Version +from torch.distributed import ProcessGroup from torch.nn import Parameter from torch.optim import Optimizer -from torch.distributed import ProcessGroup from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam -from colossalai.utils import disposable, get_current_device, is_ddp_ignored - -from .chunk import Chunk, ChunkManager -from .gemini_ddp import GeminiDDP -from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, - init_tensor_as_customization_distributed, get_device_mesh, get_sharding_spec, + init_as_dtensor, + init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, - get_global_shape, - init_as_dtensor ) +from colossalai.utils import disposable, get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager +from .gemini_ddp import GeminiDDP __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] @@ -312,7 +310,7 @@ def _maybe_move_fp32_params(self): chunk16 = self.param_to_chunk16[fake_param] chunk32 = chunk16.paired_chunk - if chunk32.device_type == "cuda": + if chunk32.device_type == "cuda" or chunk32.device_type == "npu": continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: @@ -326,7 +324,7 @@ def _maybe_move_fp32_params(self): for fake_param in group["params"]: chunk16 = self.param_to_chunk16[fake_param] chunk32 = chunk16.paired_chunk - if chunk32.device_type == "cuda": + if chunk32.device_type == "cuda" or chunk32.device_type == "npu": state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -479,15 +477,19 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() if is_dtensor: state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) - state_tensor = init_as_dtensor(state_tensor, - device_mesh=device_mesh, - sharding_spec=shard_spec, - global_shape = global_shape) + state_tensor = init_as_dtensor( + state_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape=global_shape, + ) elif is_customized_distributed: state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) - init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + init_tensor_as_customization_distributed( + state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn + ) state_tensor = 
gather_distributed_param(state_tensor, keep_vars=False).cpu() - + collected_states[state_name] = state_tensor.reshape(global_shape) return collected_states @@ -533,13 +535,14 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: collected_states[state_name] = torch.reshape(state_tensor, param.shape) if is_dtensor: state_tensor = state_tensor.to(param.device) - state_tensor = init_as_dtensor(state_tensor, - sharding_spec=shard_spec, - device_mesh=device_mesh, - global_shape=global_shape) + state_tensor = init_as_dtensor( + state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape + ) elif is_customized_distributed: state_tensor = state_tensor.to(param.device) - init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + init_tensor_as_customization_distributed( + state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn + ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() return collected_states @@ -548,7 +551,7 @@ def pack_optimizer_states_to_tensor( self, param_id: int, state_names: list, - device: torch.device = torch.device("cuda"), + device: torch.device = get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -705,7 +708,7 @@ def cast(param, state_range, value, key=None): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) - + if is_dtensor: value = torch.reshape(value, global_shape) value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) From 217c9dce136235717d92804282476da5530856d7 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Nov 2023 18:35:52 +0800 Subject: [PATCH 3/3] [example] llama2+gemini support npu --- examples/language/llama2/benchmark.py | 5 +++-- examples/language/llama2/performance_evaluator.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index ce13ebbf617d..ec8a1b9b8ff2 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai +import colossalai.utils.device as device_utils from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator @@ -190,7 +191,7 @@ def empty_init(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -216,7 +217,7 @@ def empty_init(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/performance_evaluator.py 
b/examples/language/llama2/performance_evaluator.py index a57c1e0e9ae3..ca0a7d140462 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,7 +5,9 @@ import torch.distributed as dist from torch import Tensor +import colossalai.utils.device as device_utils from colossalai.cluster import DistCoordinator +from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=torch.cuda.current_device()) + tensor = torch.tensor([x], device=get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -77,13 +79,13 @@ def on_step_start(self, step: int) -> None: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - torch.cuda.synchronize() + device_utils.synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - torch.cuda.synchronize() + device_utils.synchronize() self.timer.end() batch_size, seq_len = input_ids.shape
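
Notes on the dispatch pattern introduced in PATCH 1: every wrapper in colossalai/utils/device.py now forwards to a single _dispatch_device_func helper instead of repeating the cuda/npu branch in each function. The sketch below restates that pattern in a self-contained form; the IS_NPU_AVAILABLE probe is an assumption about how the module detects the Ascend backend (the real definition is outside the hunks shown above), while the dispatcher and the two wrappers mirror the patched file.

    import torch

    # Assumed probe: importing torch_npu registers the torch.npu namespace.
    try:
        import torch_npu  # noqa: F401
        IS_NPU_AVAILABLE = torch.npu.is_available()
    except ImportError:
        IS_NPU_AVAILABLE = False


    def _dispatch_device_func(fn_name: str, *args, **kwargs):
        # Route the call to the active accelerator namespace; CUDA takes priority.
        if torch.cuda.is_available():
            return getattr(torch.cuda, fn_name)(*args, **kwargs)
        elif IS_NPU_AVAILABLE:
            return getattr(torch.npu, fn_name)(*args, **kwargs)
        else:
            raise RuntimeError("No device available")


    def synchronize():
        # Backend-agnostic replacement for torch.cuda.synchronize().
        return _dispatch_device_func("synchronize")


    def max_memory_allocated(device=None) -> int:
        # Backend-agnostic replacement for torch.cuda.max_memory_allocated().
        return _dispatch_device_func("max_memory_allocated", device)

Call sites then migrate from the CUDA-specific API to these wrappers, which is what PATCH 3 does in the llama2 example: torch.cuda.synchronize() becomes device_utils.synchronize() and torch.cuda.max_memory_allocated() becomes device_utils.max_memory_allocated(), so the benchmark no longer hard-codes CUDA calls. When running Gemini on an NPU, PATCH 2 additionally requires the static placement policy, since the assertion added in gemini_plugin.py rejects any other value of placement_policy while IS_NPU_AVAILABLE is set.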