81 changes: 81 additions & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -19,6 +19,7 @@ class LlamaPipelineForwards:
under pipeline setting.
'''

@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
@@ -169,6 +170,7 @@ def custom_forward(*inputs):
# always return dict for intermediate stage
return {'hidden_states': hidden_states}

@staticmethod
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
@@ -276,6 +278,7 @@ def llama_for_causal_lm_forward(
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}

@staticmethod
def llama_for_sequence_classification_forward(
self: LlamaForSequenceClassification,
input_ids: torch.LongTensor = None,
@@ -388,6 +391,84 @@ def llama_for_sequence_classification_forward(
return {'hidden_states': hidden_states}


class LlamaInferenceForwards:
"""
This class holds the forward methods used for llama inference.
"""

@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[
torch.LongTensor] = None,    # TODO: this can also be removed once sin/cos are cached in inferinfo
past_key_values: Optional[List[
torch.FloatTensor]] = None,    # TODO: may be removed once the memory cache manager is done
inputs_embeds: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
inferinfo=None,
):
# keep only the basic arguments needed for inference
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

seq_length_with_past = seq_length
past_key_values_length = 0

if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)

hidden_states = inputs_embeds

for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
)

hidden_states = layer_outputs[0]

hidden_states = self.norm(hidden_states)

if not return_dict:
return hidden_states
return BaseModelOutputWithPast(last_hidden_state=hidden_states,)
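As a point of reference, here is a minimal sketch (not part of this diff) of how this trimmed-down forward could be exercised once the model has been sharded. The names sharded_model, prefill, and the prompt tensor are all illustrative, and inferinfo is left at its default:

import torch

@torch.no_grad()
def prefill(sharded_model, input_ids):
    # attention_mask and position_ids are reconstructed inside the forward when omitted
    outputs = sharded_model(input_ids=input_ids, return_dict=True)
    return outputs.last_hidden_state    # (batch_size, seq_len, hidden_size)

hidden = prefill(sharded_model, torch.tensor([[1, 2, 3]], dtype=torch.long, device='cuda'))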


def get_llama_flash_attention_forward():

from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
14 changes: 12 additions & 2 deletions colossalai/shardformer/policies/auto_policy.py
@@ -1,5 +1,6 @@
import importlib
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn

@@ -130,6 +131,12 @@ class PolicyLocation:
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}

_INFER_POLICY_LIST = {
# LlaMa
"transformers.models.llama.modeling_llama.LlamaModel":
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy")
}


def import_policy(policy_location: PolicyLocation) -> Policy:
"""
@@ -151,7 +158,7 @@ def _fullname(obj):
return module + '.' + klass.__qualname__


def get_autopolicy(model: nn.Module) -> Policy:
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
r"""
Return the auto policy for the model

@@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy:
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)

if policy_location is None:
raise NotImplementedError(
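As a quick usage sketch (assuming transformers is installed; the tiny config sizes are illustrative), the new inference_only flag routes the lookup through _INFER_POLICY_LIST instead of _POLICY_LIST:

from transformers import LlamaConfig, LlamaModel

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# Build a minimal llama model purely to resolve its policy.
model = LlamaModel(LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2, num_attention_heads=4))
policy = get_autopolicy(model, inference_only=True)    # resolves to LlamaModelInferPolicy
print(type(policy).__name__)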
20 changes: 19 additions & 1 deletion colossalai/shardformer/policies/llama.py
@@ -7,7 +7,7 @@

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,3 +263,21 @@ def get_held_layers(self) -> List[Module]:
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []


class LlamaModelInferPolicy(LlamaPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = super().module_policy()
# configure default shard config for inference
self.shard_config._infer()

infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

return policy
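Putting the pieces together, a sketch of the intended end-to-end flow, reusing the minimal model from the sketch above (note that in the tests below this runs only after colossalai.launch has initialized the process group):

import copy

from colossalai.shardformer import ShardConfig, ShardFormer

# With inference_only=True, ModelSharder resolves LlamaModelInferPolicy
# automatically, so no explicit policy needs to be passed.
shard_config = ShardConfig(inference_only=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(copy.deepcopy(model))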
7 changes: 7 additions & 0 deletions colossalai/shardformer/shard/shard_config.py
@@ -28,6 +28,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
inference_only: bool = False

# pipeline_parallel_size: int
# data_parallel_size: int
@@ -57,3 +58,9 @@ def _turn_on_all_optimization(self):
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True

def _infer(self):
"""
Set default parameters for inference: pipeline parallelism is not used, so the stage manager is cleared.
"""
self.pipeline_stage_manager = None
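A small sketch of what the new flag plus _infer() amount to; note that _infer() is invoked by LlamaModelInferPolicy.module_policy() above rather than by user code:

from colossalai.shardformer import ShardConfig

config = ShardConfig(inference_only=True)
config._infer()
assert config.pipeline_stage_manager is None    # pipeline parallelism is disabled for inference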
3 changes: 2 additions & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -27,7 +27,8 @@ class ModelSharder(object):

def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
print(self.policy)
self.shard_config = shard_config

def shard(self) -> List[Dict[int, Tensor]]:
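One caveat: shard_config defaults to None in this constructor, so shard_config.inference_only raises AttributeError when no config is passed, and print(self.policy) looks like leftover debugging. A None-safe sketch of the same constructor (not what this diff ships):

def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
    self.model = model
    # Guard against the default shard_config=None before reading inference_only
    inference_only = shard_config.inference_only if shard_config is not None else False
    self.policy = get_autopolicy(self.model, inference_only) if policy is None else policy
    self.shard_config = shard_config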
53 changes: 53 additions & 0 deletions tests/test_infer/_utils.py
@@ -0,0 +1,53 @@
import copy

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(
model_fn,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
enable_flash_attention=False,
enable_jit_fused=False,
):
# create new model
org_model = model_fn()

# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
inference_only=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()


def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)

shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)

return org_output, shard_output
55 changes: 55 additions & 0 deletions tests/test_infer/test_llama_infer.py
@@ -0,0 +1,55 @@
import os

import pytest
import torch
from torch import distributed as dist

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
org_model, sharded_model = build_model(model_fn, **test_config)

org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)

print('original output', org_output[0])
print('infer output', infer_output[0])


@parameterize('test_config', [{
'enable_flash_attention': False,
}])
def run_llama_test(test_config):

sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name != "transformers_llama":
continue
check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
torch.cuda.empty_cache()


def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 1)


if __name__ == "__main__":
test_llama()
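Assuming the repository's usual test conventions (the dist marker and spawn helper used above), this file can be run directly as a script or through pytest; exact flags may vary with the local GPU/NCCL setup:

python tests/test_infer/test_llama_infer.py
pytest tests/test_infer/test_llama_infer.py -q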