Merged
3 changes: 2 additions & 1 deletion colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -8,6 +8,7 @@
 from torch import Tensor
 
 from colossalai.logging import get_dist_logger
+from colossalai.utils.device import get_current_device
 
 __all__ = ["BaseGradScaler"]
 
@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
 
     def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
         self._verbose = verbose
 
         if self._verbose:
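The hunk above swaps the long-deprecated, CUDA-only `torch.cuda.FloatTensor` constructor for the `torch.tensor` factory call, which accepts any target device at runtime. A minimal sketch of the pattern (the scale value and CPU fallback are illustrative, not taken from the PR):

```python
import torch

# Old, CUDA-only and deprecated:
#   scale = torch.cuda.FloatTensor([initial_scale])
# New, device-agnostic factory call; `device` may point at cuda, npu, or cpu.
initial_scale = 2.0**16  # illustrative default
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
scale = torch.tensor([initial_scale], device=device, dtype=torch.float)
```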
8 changes: 5 additions & 3 deletions colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,6 +5,8 @@
 
 import torch
 
+from colossalai.utils.device import get_current_device
+
 from .base_grad_scaler import BaseGradScaler
 
 __all__ = ["DynamicGradScaler"]
@@ -37,12 +39,12 @@ def __init__(
     ):
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.cuda.FloatTensor([min_scale])
+            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._min_scale = None
 
         if max_scale:
-            self._max_scale = torch.cuda.FloatTensor([max_scale])
+            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._max_scale = None
 
@@ -115,7 +117,7 @@ def state_dict(self):
         return state_dict
 
     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+        self._scale = state_dict["scale"].to(get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
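The `load_state_dict` hunk is the behavioral core of this file's change: `.cuda(torch.cuda.current_device())` hard-codes a CUDA destination and would fail on an NPU-only host, while `.to(device)` moves the restored scale to whichever device is current. A small illustration of the round trip, assuming a CPU fallback on machines without an accelerator:

```python
import torch

# Save side: the scale tensor serializes with torch.save like any other state.
state_dict = {"scale": torch.tensor([2.0**15])}  # e.g. recovered via torch.load(...)

# Restore side: `.to(...)` works for cuda, npu, and cpu targets alike,
# whereas `.cuda(...)` raises on hosts without CUDA.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
scale = state_dict["scale"].to(device)
```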
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@
 import torch
 from torch.fx.node import Node
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
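This file, and every legacy module below, makes the same one-line change: `get_current_device` now comes from `colossalai.utils.device` rather than `colossalai.utils.cuda`, so the helper can resolve non-CUDA accelerators. A hypothetical sketch of what such a helper could look like, assuming Ascend support is exposed through a `torch.npu` namespace (as the `torch_npu` package provides); the actual implementation in the repo may differ:

```python
import torch


def get_current_device() -> torch.device:
    """Return the device bound to the current process, preferring an accelerator."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    # torch_npu (Ascend) patches a `torch.npu` namespace onto torch when imported.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.device(f"npu:{torch.npu.current_device()}")
    return torch.device("cpu")
```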
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -306,7 +306,7 @@ def control_device(self) -> bool:
         return True
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def configure(
         self,
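With `supported_devices` extended to `["cuda", "npu"]`, callers that validate the runtime device against the plugin will now accept NPU hosts. A hypothetical usage check; `check_device` and its error message are illustrative, not ColossalAI API:

```python
from typing import List


def check_device(supported: List[str], device_type: str) -> None:
    """Reject devices the plugin cannot drive, e.g. before boosting a model."""
    if device_type not in supported:
        raise ValueError(f"device '{device_type}' not in supported set {supported}")


check_device(["cuda", "npu"], "npu")  # passes after this PR; would raise before
```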
7 changes: 5 additions & 2 deletions colossalai/initialize.py
@@ -11,7 +11,7 @@
 
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
 
 
 def launch(
@@ -47,12 +47,15 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")
 
+    if IS_NPU_AVAILABLE and backend == "nccl":
+        backend = "hccl"
+
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
     # set cuda device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
         # if local rank is not given, calculate automatically
         set_device(local_rank)
 
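The `launch` hunk routes NPU jobs onto a backend the hardware actually supports: NCCL is CUDA-only, so when `IS_NPU_AVAILABLE` is set, a request for `"nccl"` is silently rewritten to Huawei's HCCL before the process group is created. A standalone sketch of that mapping, factored into a helper for clarity (the helper name is illustrative):

```python
def pick_backend(requested: str, npu_available: bool) -> str:
    """Map the requested process-group backend onto one the hardware supports."""
    if npu_available and requested == "nccl":
        return "hccl"  # HCCL is the Ascend-NPU counterpart of NCCL
    return requested


assert pick_backend("nccl", npu_available=True) == "hccl"
assert pick_backend("gloo", npu_available=True) == "gloo"  # CPU backends pass through
```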
2 changes: 1 addition & 1 deletion colossalai/kernel/cuda_native/mha/utils.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from einops import rearrange
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 
 class Unpad(torch.autograd.Function):
2 changes: 1 addition & 1 deletion colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -12,7 +12,7 @@
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._base_schedule import BaseSchedule
 
2 changes: 1 addition & 1 deletion colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -9,7 +9,7 @@
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._pipeline_schedule import PipelineSchedule
 
2 changes: 1 addition & 1 deletion colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -22,7 +22,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
2 changes: 1 addition & 1 deletion colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -18,7 +18,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
2 changes: 1 addition & 1 deletion colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -19,7 +19,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
2 changes: 1 addition & 1 deletion colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -27,7 +27,7 @@
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (
2 changes: 1 addition & 1 deletion colossalai/legacy/nn/layer/vanilla/layers.py
@@ -10,7 +10,7 @@
 from colossalai.legacy.context import seed
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ..utils import to_2tuple
 
2 changes: 1 addition & 1 deletion colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
@@ -3,7 +3,7 @@
 from time import time
 from typing import List
 
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy
46 changes: 26 additions & 20 deletions colossalai/pipeline/schedule/generate.py
@@ -10,7 +10,7 @@
 from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule
@@ -93,9 +93,7 @@ def _prepare_inputs_for_interval_stage(self):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {
-            'infer_state': self.mb_manager.cur_descrption.infer_state
-        }
+        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
         return model_inputs
 
     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ def _recv_pre_stage(self) -> Any:
 
     def _init_infer_state_action(self) -> None:
         """
-        This action is only for no first stage, to load batch and init infer_state.
-        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
+        This action is only for no first stage, to load batch and init infer_state.
+        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
         self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ def _load_stage_action(self, model: Module) -> None:
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _gen_token_action(self, model: Module):
         """
-        This action is only for first stage
+        This action is only for first stage
         1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
@@ -178,18 +176,18 @@ def _head_encoding_action(self, model: Module):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, None, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -319,7 +317,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 torch.cuda.synchronize()
                 self.timestamps[self.mb_manager.idx].append(time.time())
             self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
         else:
@@ -330,18 +328,23 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                     output_dict = model_forward(model, inputs_dict, interval_inputs)
                 else:
                     assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +353,10 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
             if self.mb_manager.cur_state is Status.PREFILL:
                 inputs_dict = self.load_micro_batch()
                 self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {
+                "hidden_states": hidden_states["hidden_states"],
+                "infer_state": self.mb_manager.cur_infer_state,
+            }
             output_dict = model_forward(model, inputs_dict, interval_inputs)
 
         # Current microbatch is not DONE, send hidden_state to next stage
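Most of the `generate.py` hunks are mechanical formatter output (single quotes normalized to double quotes, long dict literals reflowed), but they all funnel through one call shape: `model_forward(model, inputs_dict, interval_inputs)`. A plausible sketch of that helper, which lives in `colossalai/pipeline/schedule/_utils.py` and may differ in detail:

```python
from typing import Any, Dict, Optional

from torch import nn


def model_forward(model: nn.Module, data: Optional[Dict[str, Any]], internal_inputs: Optional[Dict[str, Any]]):
    """Run the model on the micro-batch inputs merged with pipeline-internal state."""
    if data is None:
        data = {}
    if internal_inputs is None:
        internal_inputs = {}
    return model(**data, **internal_inputs)
```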
2 changes: 1 addition & 1 deletion colossalai/pipeline/schedule/interleaved_pp.py
@@ -9,7 +9,7 @@
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
2 changes: 1 addition & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -9,7 +9,7 @@
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 from ._utils import (
     detach,