[infer] Infer/llama demo #4503
Merged: tiandiao123 merged 6 commits into hpcaitech:feature/colossal-inference from CjhHa1:infer/llama_demo on Aug 24, 2023.
The first new file, tests/test_infer/_utils.py (53 lines, as referenced by the import in the test below):

```python
import copy

import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(
    model_fn,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
):
    # create new model
    org_model = model_fn()

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention,
                               enable_jit_fused=enable_jit_fused,
                               inference_only=True)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)

    return org_output, shard_output
```
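build_model keeps an unmodified copy of the model alongside the ShardFormer-optimized one, and run_infer pushes the same batch through both so their outputs can be compared. The test in this PR only prints the two outputs; below is a minimal sketch of an explicit numerical check (the helper name and tolerances are illustrative additions, not part of the PR, and the caller is assumed to have already extracted a logits tensor from whatever output_transform_fn returns):

```python
import torch


def assert_infer_close(org_logits: torch.Tensor, infer_logits: torch.Tensor,
                       rtol: float = 1e-5, atol: float = 1e-5) -> None:
    # Element-wise comparison of the original-model and sharded-model outputs
    # produced by the same run_infer call.
    assert org_logits.shape == infer_logits.shape, "output shapes differ"
    assert torch.allclose(org_logits, infer_logits, rtol=rtol, atol=atol), \
        "sharded inference output diverges from the original model"
```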
The second new file adds the LLaMA inference test (55 lines):

```python
import os

import pytest
import torch
from torch import distributed as dist

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
    org_model, sharded_model = build_model(model_fn, **test_config)

    org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)

    print('original output', org_output[0])
    print('infer output', infer_output[0])


@parameterize('test_config', [{
    'enable_flash_attention': False,
}])
def run_llama_test(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name != "transformers_llama":
            continue
        check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
    torch.cuda.empty_cache()


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, 1)


if __name__ == "__main__":
    test_llama()
```
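The test matrix is intentionally small: the only parameterized configuration disables flash attention, and spawn launches a single process that calls colossalai.launch with the nccl backend. Covering more of build_model's switches would just mean extending the @parameterize list; the configurations below are a hypothetical sketch (whether fused normalization or flash attention actually run depends on the installed kernels and GPU, which this PR does not assert):

```python
# Hypothetical extension of the parameterized test configurations; each dict
# is forwarded to build_model(**test_config). Not part of this PR.
EXTRA_TEST_CONFIGS = [
    {'enable_flash_attention': False},
    {'enable_flash_attention': True},
    {'enable_flash_attention': True, 'enable_fused_normalization': True},
]
```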