-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[shardformer] Unit test #3928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
FrankLeeeee
merged 6 commits into
hpcaitech:feature/shardformer
from
FoolPlayer:unit_test
Jun 12, 2023
Merged
[shardformer] Unit test #3928
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
3df5bca
fix bug in slicer, add slicer unit test
FoolPlayer b4fac4e
add dropout test
FoolPlayer 27e3f7d
use pid as dropout seed
FoolPlayer cf2a65a
updata dropout test with local pattern
FoolPlayer d260db7
ad todo
FoolPlayer 06bb4ef
Merge branch 'feature/shardformer' of https://github.com/FoolPlayer/C…
FoolPlayer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| import colossalai | ||
| from colossalai.logging import disable_existing_loggers | ||
| from colossalai.shardformer.layer.dropout import Dropout1D | ||
| from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
|
|
||
| CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) | ||
|
|
||
|
|
||
| def check_dropout(rank, world_size, port): | ||
| disable_existing_loggers() | ||
| colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') | ||
|
|
||
| # prepare data | ||
| input = torch.randn(5, 4).to('cuda') | ||
| dropout = Dropout1D(p=0.4).to('cuda') | ||
| output_list = [] | ||
| # compare the dropout pattern in each device | ||
| for i in range(2): | ||
| output = dropout(input) | ||
| output_list.append(output) | ||
| dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] | ||
| torch.distributed.all_gather(dist_output_list, output) | ||
| for j in range(world_size): | ||
| for k in range(world_size): | ||
| if j != k: | ||
| mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0) | ||
| assert torch.all( | ||
| mask | ||
| ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}" | ||
| # compare the dropout pattern in loacl device | ||
| for i in range(len(output_list)): | ||
| for j in range(len(output_list)): | ||
| if i != j: | ||
| mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0) | ||
| assert torch.all( | ||
| mask | ||
| ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}" | ||
|
|
||
|
|
||
| @pytest.mark.dist | ||
| @rerun_if_address_is_in_use() | ||
| def test_dropout(): | ||
| spawn(check_dropout, 2) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| test_dropout() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| import colossalai | ||
| from colossalai.logging import disable_existing_loggers | ||
| from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer | ||
| from colossalai.shardformer.shard.shard_config import ShardConfig | ||
| from colossalai.shardformer.shard.slicer import Slicer | ||
| from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
|
|
||
| CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) | ||
|
|
||
|
|
||
| def check_slicer(rank, world_size, port, in_feature, out_feature): | ||
| disable_existing_loggers() | ||
| colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') | ||
| # initialize slicer | ||
| shardconfig = ShardConfig(rank=rank, world_size=world_size) | ||
| slicer = Slicer(shardconfig) | ||
| # initialize test data | ||
| weight = torch.randn(in_feature, out_feature) | ||
| bias = torch.randn(out_feature) | ||
| policy_layer_cls_list = [Layer, Col_Layer, Row_Layer] | ||
| n_cast_list = [None, 2, 3, 4] | ||
| # weight and bias | ||
| for n_cast in n_cast_list: | ||
| sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast) | ||
| expected_sliced_weight = weight | ||
| expected_sliced_bias = bias | ||
| assert torch.equal( | ||
| sliced_weight, expected_sliced_weight | ||
| ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" | ||
| assert torch.equal( | ||
| sliced_bias, expected_sliced_bias | ||
| ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" | ||
|
|
||
| sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast) | ||
| if (n_cast is None): | ||
| expected_sliced_weight = weight.chunk(world_size, dim=0)[rank] | ||
| expected_sliced_bias = bias.chunk(world_size)[rank] | ||
| else: | ||
| chunks = weight.chunk(world_size * n_cast, dim=0) | ||
| expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0) | ||
| chunks = bias.chunk(world_size * n_cast, dim=0) | ||
| expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)]) | ||
| assert torch.equal( | ||
| sliced_weight, expected_sliced_weight | ||
| ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" | ||
| assert torch.equal( | ||
| sliced_bias, expected_sliced_bias | ||
| ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}" | ||
|
|
||
| sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast) | ||
| if (n_cast is None): | ||
| expected_sliced_weight = weight.chunk(world_size, dim=1)[rank] | ||
| expected_sliced_bias = bias | ||
| else: | ||
| chunks = weight.chunk(world_size * n_cast, dim=1) | ||
| expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1) | ||
| expected_sliced_bias = bias | ||
| assert torch.equal( | ||
| sliced_weight, expected_sliced_weight | ||
| ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" | ||
| assert torch.equal( | ||
| sliced_bias, expected_sliced_bias | ||
| ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" | ||
|
|
||
|
|
||
| @pytest.mark.dist | ||
| @rerun_if_address_is_in_use() | ||
| def test_slicer(): | ||
| args = dict(in_feature=24, out_feature=48) | ||
| spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature']) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| test_slicer() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.