diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 79661a44424f..439d13dcfc11 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -8,6 +8,7 @@
 from torch import Tensor
 
 from colossalai.logging import get_dist_logger
+from colossalai.utils.device import get_current_device
 
 __all__ = ["BaseGradScaler"]
@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
 
     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
         self._verbose = verbose
 
         if self._verbose:
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 65133a4b3712..86ba919ee696 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,6 +5,8 @@
 
 import torch
 
+from colossalai.utils.device import get_current_device
+
 from .base_grad_scaler import BaseGradScaler
 
 __all__ = ["DynamicGradScaler"]
@@ -37,12 +39,12 @@ def __init__(
     ):
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.cuda.FloatTensor([min_scale])
+            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._min_scale = None
 
         if max_scale:
-            self._max_scale = torch.cuda.FloatTensor([max_scale])
+            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._max_scale = None
@@ -115,7 +117,7 @@ def state_dict(self):
         return state_dict
 
     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+        self._scale = state_dict["scale"].to(get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index a6b4904f2617..a6628e29c2bc 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@
 import torch
 from torch.fx.node import Node
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 09343138f5ff..89102820cd38 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -306,7 +306,7 @@ def control_device(self) -> bool:
         return True
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def configure(
         self,
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index aac57d34a2c1..25076b742c26 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -11,7 +11,7 @@
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
 
 
 def launch(
@@ -47,12 +47,15 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")
 
+    if IS_NPU_AVAILABLE and backend == "nccl":
+        backend = "hccl"
+
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
     # set cuda device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
         # if local rank is not given, calculate automatically
         set_device(local_rank)
diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py
index fe31921b961b..5f01e3ef327d 100644
--- a/colossalai/kernel/cuda_native/mha/utils.py
+++ b/colossalai/kernel/cuda_native/mha/utils.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from einops import rearrange
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 
 class Unpad(torch.autograd.Function):
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 4fc5040f6983..5fd5602e790c 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -12,7 +12,7 @@
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._base_schedule import BaseSchedule
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 867c3dfa819b..4cd7e47c37f1 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -9,7 +9,7 @@
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._pipeline_schedule import PipelineSchedule
diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
index 8304cd2e1eb7..b6ec5347f2e2 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -22,7 +22,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
index 3b2e032e5127..f81c5334ad77 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -18,7 +18,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index fc2e35f36cbc..b451a4031c25 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -19,7 +19,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
index 196679994197..16e515f87da3 100644
--- a/colossalai/legacy/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -27,7 +27,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (
diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
index 12965a4a6409..590ad5ff6085 100644
--- a/colossalai/legacy/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -10,7 +10,7 @@
 from colossalai.legacy.context import seed
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..utils import to_2tuple
diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
index 19f77d4305af..e336717f4164 100644
--- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
+++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
@@ -3,7 +3,7 @@
 from time import time
 from typing import List
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index db02dab59ca6..7022efd7d1fa 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -10,7 +10,7 @@
 from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule
@@ -93,9 +93,7 @@ def _prepare_inputs_for_interval_stage(self):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {
-            'infer_state': self.mb_manager.cur_descrption.infer_state
-        }
+        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
         return model_inputs
 
     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ def _recv_pre_stage(self) -> Any:
 
     def _init_infer_state_action(self) -> None:
         """
-        This action is only for no first stage, to load batch and init infer_state.
-        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
+        This action is only for no first stage, to load batch and init infer_state.
+        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
         self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ def _load_stage_action(self, model: Module) -> None:
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _gen_token_action(self, model: Module):
         """
-        This action is only for first stage
+        This action is only for first stage
         1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
@@ -178,18 +176,18 @@ def _head_encoding_action(self, model: Module):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, None, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -319,7 +317,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
             self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
         else:
@@ -330,18 +328,23 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                     output_dict = model_forward(model, inputs_dict, interval_inputs)
             else:
                 assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +353,10 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 if self.mb_manager.cur_state is Status.PREFILL:
                     inputs_dict = self.load_micro_batch()
                     self.mb_manager.add_descrption(inputs_dict)
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 output_dict = model_forward(model, inputs_dict, interval_inputs)
 
         # Current microbatch is not DONE, send hidden_state to next stage
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 780437155c61..cbf6dd80f3e0 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -9,7 +9,7 @@
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 4eaf135fd5db..4a061ae43844 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -9,7 +9,7 @@
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import (
     detach,
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 42efe9a44308..8387bb5e365e 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -2,16 +2,19 @@
 # -*- encoding: utf-8 -*-
 import warnings
 from abc import ABC, abstractmethod
+
 import torch.nn as nn
+
 from colossalai.lazy import LazyInitContext
-from ._operation import hook_paramter_in_backward
+
+from ._operation import hook_paramter_in_backward
 from .utils import SeqParallelUtils
 
 __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+
     EnableFastLayerNorm = True
 except ImportError:
     EnableFastLayerNorm = False
@@ -19,10 +22,27 @@
 try:
     from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
     from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+    class FusedLayerNormWithHook(ApexFusedLayerNorm):
+        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+            super().__init__(normalized_shape, eps, elementwise_affine)
+
+        def forward(self, input):
+            output = super().forward(input)
+            output = hook_paramter_in_backward(output, self.weight, self.bias)
+            return output
+
+    class FusedRMSNormWithHook(ApexFusedRMSNorm):
+        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+            super().__init__(normalized_shape, eps, elementwise_affine)
+
+        def forward(self, input):
+            output = super().forward(input)
+            output = hook_paramter_in_backward(output, self.weight)
+            return output
+
 except ImportError:
-    warnings.warn(
-        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
-    )
+    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
@@ -52,6 +72,7 @@
 ]
 
 if EnableFastLayerNorm:
+
     class FastLayerNormWithHook(FastLayerNorm):
         def __init__(self, hidden_size, eps=0.00001):
             super().__init__(hidden_size, eps)
@@ -60,25 +81,7 @@ def forward(self, input):
             output = super().forward(input)
             output = hook_paramter_in_backward(output, self.weight, self.bias)
             return output
-
-class FusedLayerNormWithHook(ApexFusedLayerNorm):
-    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-        super().__init__(normalized_shape, eps, elementwise_affine)
-
-    def forward(self, input):
-        output = super().forward(input)
-        output = hook_paramter_in_backward(output, self.weight, self.bias)
-        return output
-
-class FusedRMSNormWithHook(ApexFusedRMSNorm):
-    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-        super().__init__(normalized_shape, eps, elementwise_affine)
-
-    def forward(self, input):
-        output = super().forward(input)
-        output = hook_paramter_in_backward(output, self.weight)
-        return output
-
+
 
 class BaseLayerNorm(ABC):
     @abstractmethod
@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
+
    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedRMSNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
        )
-
+
    @staticmethod
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
@@ -264,7 +268,7 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
             nn.Module: FusedRMSNorm module.
""" try: - from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + pass except ImportError: raise ImportError( "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" @@ -282,7 +286,9 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg eps = module.eps elementwise_affine = module.elementwise_affine - rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + rmsnorm = FusedRMSNormWithHook( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine + ) rmsnorm.weight = module.weight diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 3ec39b949a23..0246a35e2a1b 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,7 +7,7 @@ is_ddp_ignored, set_seed, ) -from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize +from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer @@ -29,4 +29,5 @@ "set_seed", "is_ddp_ignored", "set_device", + "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py deleted file mode 100644 index 6bfb08d1f04a..000000000000 --- a/colossalai/utils/cuda.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Optional - -import torch -import torch.distributed as dist - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - else: - return torch.device("cpu") - - -def synchronize(): - """Similar to cuda.synchronize(). - Waits for all kernels in all streams on a CUDA device to complete. - """ - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def empty_cache(): - """Similar to cuda.empty_cache() - Releases all unoccupied cached memory currently held by the caching allocator. - """ - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % torch.cuda.device_count() - torch.cuda.set_device(index) diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py new file mode 100644 index 000000000000..9b78f881b3bc --- /dev/null +++ b/colossalai/utils/device.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional + +import torch +import torch.distributed as dist + +IS_NPU_AVAILABLE: bool = False +try: + import torch_npu # noqa + + IS_NPU_AVAILABLE = torch.npu.is_available() +except ImportError: + pass + + +def set_to_cuda(models): + """Send model to gpu. 
+ + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(get_current_device()) + else: + return models.to(get_current_device()) + + +def get_current_device() -> torch.device: + """ + Returns currently selected device (gpu/cpu). + If cuda available, return gpu, otherwise return cpu. + """ + if torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + elif IS_NPU_AVAILABLE: + return torch.device(f"npu:{torch.npu.current_device()}") + else: + return torch.device("cpu") + + +def synchronize(): + """Similar to cuda.synchronize(). + Waits for all kernels in all streams on a CUDA device to complete. + """ + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif IS_NPU_AVAILABLE: + torch.npu.synchronize() + + +def device_count() -> int: + if torch.cuda.is_available(): + return torch.cuda.device_count() + elif IS_NPU_AVAILABLE: + return torch.npu.device_count() + else: + raise RuntimeError("No device available") + + +def empty_cache(): + """Similar to cuda.empty_cache() + Releases all unoccupied cached memory currently held by the caching allocator. + """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif IS_NPU_AVAILABLE: + torch.npu.empty_cache() + + +def set_device(index: Optional[int] = None) -> None: + if index is None: + index = dist.get_rank() % device_count() + if torch.cuda.is_available(): + torch.cuda.set_device(index) + elif IS_NPU_AVAILABLE: + torch.npu.set_device(index) + + +def Stream(device=None, priority=0, **kwargs): + if torch.cuda.is_available(): + return torch.cuda.Stream(device, priority, **kwargs) + elif IS_NPU_AVAILABLE: + return torch.npu.Stream(device, priority, **kwargs) + else: + raise RuntimeError("No device available") + + +def stream(stream_): + if torch.cuda.is_available(): + return torch.cuda.stream(stream_) + elif IS_NPU_AVAILABLE: + return torch.npu.stream(stream_) + else: + raise RuntimeError("No device available") + + +def set_stream(stream_): + if torch.cuda.is_available(): + return torch.cuda.set_stream(stream_) + elif IS_NPU_AVAILABLE: + return torch.npu.set_stream(stream_) + else: + raise RuntimeError("No device available") + + +def current_stream(device=None): + if torch.cuda.is_available(): + return torch.cuda.current_stream(device) + elif IS_NPU_AVAILABLE: + return torch.npu.current_stream(device) + else: + raise RuntimeError("No device available") + + +def default_stream(device=None): + if torch.cuda.is_available(): + return torch.cuda.default_stream(device) + elif IS_NPU_AVAILABLE: + return torch.npu.default_stream(device) + else: + raise RuntimeError("No device available") diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 2f61817f0461..8ab6b46f28b6 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .cuda import synchronize +from .device import synchronize class Timer: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index d61082beddba..c1b35ee17f91 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,6 +12,7 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer +import colossalai.utils.device as device_utils from colossalai.amp.naive_amp.mixed_precision_mixin 
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
@@ -22,7 +23,7 @@
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 # from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -182,7 +183,7 @@ def __init__(
         # intialize communication stream for
         # communication-compuation overlapping
         if self._overlap_communication:
-            self._comm_stream = torch.cuda.Stream()
+            self._comm_stream = device_utils.Stream()
 
         # reduction hook is only used if overlapping communication
         # or stage 2 is used
@@ -216,7 +217,7 @@ def num_param_groups(self):
         return len(self._working_param_groups)
 
     def _sanity_checks(self):
-        assert torch.cuda.is_available(), "CUDA is required"
+        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
         for param_group in self.optim.param_groups:
             group_params = param_group["params"]
             for param in group_params:
@@ -339,11 +340,11 @@ def _run_reduction(self):
             if len(moe_grad_list) > 0:
                 moe_flat_grads.record_stream(stream)
             # waiting for ops in the default stream finishing
-            stream.wait_stream(torch.cuda.current_stream())
+            stream.wait_stream(device_utils.current_stream())
         else:
-            stream = torch.cuda.current_stream()
+            stream = device_utils.current_stream()
 
-        with torch.cuda.stream(stream):
+        with device_utils.stream(stream):
             group_id = self._bucket_store.current_group_id
 
             if self.moe_extra_dp_pg is None:
@@ -485,7 +486,7 @@ def backward(self, loss, retain_graph=False):
 
         # clear reduced grads
         if self._overlap_communication:
-            torch.cuda.synchronize()
+            device_utils.synchronize()
 
         self.zero_grad()
@@ -504,7 +505,7 @@ def backward_by_grad(self, tensor, grad):
 
         # clear reduced grads
         if self._overlap_communication:
-            torch.cuda.synchronize()
+            device_utils.synchronize()
 
         self.zero_grad()
@@ -620,22 +621,25 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
+        device = get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
                     all_splited_param = [
-                        torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
+                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self.moe_extra_dp_pg_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
+                    dist.all_gather(
+                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                    )
                 else:
                     all_splited_param = [
-                        torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
+                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self._world_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -657,7 +661,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
         norm_type = float(norm_type)
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -668,7 +672,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
                 total_norm_exponentiated += grad_norm_exponentiated
 
             # Sum across all model parallel GPUs.
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+            )
             torch.distributed.all_reduce(
                 total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
             )
@@ -759,6 +765,7 @@ def state_dict(self) -> Dict:
             Dict: the pytorch form state_dict
         """
         zero_state = dict()
+        device = get_current_device()
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
@@ -766,14 +773,14 @@
                     working_param = self._param_store.master_to_working_param[id(param)]
                     if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                         gather_tensor = [
-                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
+                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                         ]
-                        dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
+                        dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
                     else:
                         gather_tensor = [
-                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
+                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
                         ]
-                        dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+                        dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
                     param_state = (
                         torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                     )
@@ -820,6 +827,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
         ret_block = dict()
         ret_block_size = 0
 
+        device = get_current_device()
         local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0
@@ -836,14 +844,14 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
                 if isinstance(v, torch.Tensor) and k != "step":
                     if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                         state_tensor = [
-                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
+                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                         ]
-                        dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
+                        dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
                     else:
                         state_tensor = [
-                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
+                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
                         ]
-                        dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+                        dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
                     state_tensor = (
                         torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                    )
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 104ca254c572..3eaaf882c9ba 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -2,11 +2,14 @@
 import torch
 import torch.distributed as dist
+from torch.optim import Adam
 
 import colossalai
+import colossalai.utils.device as device_utils
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.nn.optimizer import HybridAdam
+
+# from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -19,16 +22,17 @@
 
 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+    device = device_utils.get_current_device()
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         model = model_fn()
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        optimizer = Adam(model.parameters(), lr=1e-3)
         criterion = lambda x: x.mean()
         data = data_gen_fn()
 
         data = {
-            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+            k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
         }
 
         model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
             continue
 
         err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
-        torch.cuda.empty_cache()
+        device_utils.empty_cache()
 
         if err is None:
             passed_models.append(name)
@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 
 @rerun_if_address_is_in_use()
 def test_low_level_zero_plugin(early_stop: bool = True):
-    spawn(run_dist, 4, early_stop=early_stop)
+    spawn(run_dist, 2, early_stop=early_stop)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py
index 9416ac86e325..9df7cf75aae5 100644
--- a/tests/test_legacy/test_utils/test_memory.py
+++ b/tests/test_legacy/test_utils/test_memory.py
@@ -3,7 +3,7 @@
 import colossalai
 from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.testing import spawn
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 
 def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index b8d3f45e0f34..21afff753ae6 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -9,7 +9,7 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index bfd3ebfcb28c..35323e516071 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -9,7 +9,7 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index e20428b67b41..152bf289502a 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -11,7 +11,7 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 887e495e6187..405d7d789b01 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -9,7 +9,7 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index 3c5baea138e0..351ae5f67ff7 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -9,7 +9,7 @@
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import conditional_context
+from colossalai.utils import conditional_context, get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
@@ -28,9 +28,9 @@ def forward(self, x):
 def exam_zero_1_2_grad_acc():
     local_rank = torch.distributed.get_rank()
     seed_all(2009)
-
+    device = get_current_device()
     # create model
-    zero1_model = MlpModel().cuda()
+    zero1_model = MlpModel().to(device)
     zero2_model = copy.deepcopy(zero1_model)
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
     )
     # create data
     seed_all(2021 + local_rank)
-    input_data1 = torch.randn(32, 128).cuda()
-    input_data2 = torch.randn(32, 128).cuda()
+    input_data1 = torch.randn(32, 128, device=device)
+    input_data2 = torch.randn(32, 128, device=device)
 
     def fwd_bwd_func(number, cur_data, check_flag):
         # zero-dp forward
@@ -71,14 +71,15 @@ def fwd_bwd_func(number, cur_data, check_flag):
 def exam_zero_1_grad_acc(sync):
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
+    device = get_current_device()
 
     # create models
     zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
 
     seed_all(2008)
-    zero_model = zero_model.cuda()
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+    zero_model = zero_model.to(device)
+    torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
 
     # create optimizer
     zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
 
     # create data
     seed_all(2022 + local_rank)
-    input_data1 = torch.randn(32, 128).cuda()
-    input_data2 = torch.randn(32, 128).cuda()
+    input_data1 = torch.randn(32, 128, device=device)
+    input_data2 = torch.randn(32, 128, device=device)
 
     def fwd_bwd_func(no_sync, cur_data, check_flag):
         # zero1 fwd and bwd
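Reviewer note, outside the patch itself: the sketch below illustrates how a call site is expected to migrate from hard-coded `torch.cuda` calls to the backend-agnostic wrappers this diff introduces in `colossalai/utils/device.py`. It uses only names defined in the diff (`get_current_device`, `IS_NPU_AVAILABLE`, `Stream`, `stream`, `current_stream`, `synchronize`, `empty_cache`); the snippet is illustrative, not part of the change set.

```python
# Minimal migration sketch (assumes this patch is applied). The wrappers pick
# CUDA, Ascend NPU (when torch_npu is importable), or CPU automatically.
import torch

import colossalai.utils.device as device_utils
from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device

device = get_current_device()  # torch.device("cuda:N"), ("npu:N"), or ("cpu")

# Replaces torch.cuda.FloatTensor([...]): allocate on whichever device is active.
loss_scale = torch.tensor([1024.0], device=device, dtype=torch.float)

# Replaces torch.cuda.Stream()/current_stream()/stream() for comm-compute overlap.
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
    comm_stream = device_utils.Stream()
    comm_stream.wait_stream(device_utils.current_stream())
    with device_utils.stream(comm_stream):
        pass  # launch communication kernels here
    device_utils.synchronize()  # replaces torch.cuda.synchronize()
    device_utils.empty_cache()  # replaces torch.cuda.empty_cache()
```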