[Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass #8069
Merged: masahi merged 59 commits into apache:main from AndrewZhaoLuo:andrewluo-add-fp16-conversion-pass on Jun 21, 2021.
Changes from all commits (59 commits, all by AndrewZhaoLuo):

- 425471d  Initial skeleton for fp16 pass.
- 2bd5311  Working python version of fp16 pass.
- 9fda090  Rewrite python passes in C++
- 4903a31  Extend support to things besides CallNodes. E.g. tuples and lets
- 41ac568  Rewrite how and when casting is done by checking types directly.
- bde1c58  linting and formatting
- 2101e6e  add AST header
- 8e82c40  remove todo
- 399121b  lint errors2
- c8f7428  remove i386 incompatible features
- 42b0c04  Trigger CI again
- 65b8d6c  set seed
- 8860b1c  lint
- b3b8776  address animesh's initial comments
- 479124b  mutate attributes only if they were originally floats
- 22ae9e7  initial comments from matthew
- d956848  add comment on hashing strat
- cb39e0f  add missing ;
- a00fd8b  edge case when mutating attrs
- e25c40c  Cody's easy to address comments
- 70436f5  add test to show green-red casting works
- 2c78317  remove np.random seed from each test
- 44b9782  remove as many references to fp16 types in favor of generic mixed types
- 4911d4f  rename RED, GREEN, GRAY to MIXED_PRECISION_ALLOW, etc.
- 47c2cf8  skeleton for supporting arbitrary mixed types
- 239dbfb  cool tests
- 33e286f  Using MixedModeMutator
- 418f873  rename things ToMixedPrecision
- 7d62fe1  rename passes to amp.cc
- b4ebd06  rename tests to match transform
- 8968cda  clean up typos
- 180b556  rename even better to_mixed_precision
- 528ef7b  don't insert into cache when dtypes equal
- 5ca1462  new python interface for registering ops
- 9e77cff  cleaner registering ops
- e691e4f  add fp64 structural test
- 37200fd  clean up and comments
- 4c93545  make copy of attributes
- 6aa727d  asf header
- 173801b  pylint
- f4da2df  remove TODO which is solved
- 7698920  Apply nits from code review (comaniac)
- 177f9c4  change cast_node_cache --> cast_node_cache_
- 8ddabda  add check for returned vals
- 78b5b31  better error msg
- 54d7c3d  docstring for pass in python
- 3331224  fix default behavior to be proper
- c781bf2  better error reporting via single flag
- b513fee  priority to 0
- 4fea978  address more nits
- 25d8a1d  fix story telling slightly
- a063994  restart
- 22841f1  correct docstring
- 7a933a5  change class fields to have _ at end
- a1dbb68  add class docstring
- 97fbd89  add comment on accumulation dtype hack
- 64408ee  ADT warnings
- 98e9cea  add todo
- 2634182  fix linter
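For context, here is a minimal sketch (not part of this PR's diff) of how the resulting pass is typically invoked once these commits land. It assumes the ToMixedPrecision name visible in the commit messages; the exact keyword arguments may differ between TVM versions.

```python
# Hedged sketch: converting a small FP32 Relay module to mixed precision (FP16).
import tvm
from tvm import relay

# Importing this module registers the default ALWAYS/FOLLOW/NEVER op conversions
# added in this PR (newer TVM versions may import it automatically).
from tvm.relay.transform import mixed_precision  # pylint: disable=unused-import

# A tiny FP32 network: one conv2d followed by a softmax.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=8)
out = relay.nn.softmax(relay.nn.batch_flatten(conv))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Run type inference first, then rewrite the module to mixed precision.
mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
print(mod)  # conv2d inputs are cast to float16; softmax stays float32 per the NEVER list
```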
New file (@@ -0,0 +1,195 @@):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long,unused-argument
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
# numerical reasons.
MIXED_PRECISION_ALWAYS = 0
MIXED_PRECISION_FOLLOW = 1
MIXED_PRECISION_NEVER = 2

# Default lists inspired from TF's classifications:
# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
DEFAULT_ALWAYS_LIST = [
    "nn.conv1d",
    "nn.conv2d",
    "nn.conv3d",
    "nn.conv1d_transpose",
    "nn.conv2d_transpose",
    "nn.conv3d_transpose",
    "nn.dense",
    # "nn.batch_matmul", # Handled by a special case
]
DEFAULT_FOLLOW_LIST = [
    # These ops add new data or change shape
    "nn.pad",
    "nn.batch_flatten",
    "concatenate",
    "zeros",
    "split",
    "squeeze",
    "transpose",
    "expand_dims",
    "reshape",
    "dyn.reshape",
    "broadcast_to_like",
    "dyn.broadcast_to",
    "strided_slice",
    "dyn.strided_slice",
    "take",
    "argwhere",
    "where",
    "tile",
    "dyn.tile",
    "scatter",
    "full",
    "dyn.full",
    # Comparison
    "less",
    "greater",
    "less_equal",
    "greater_equal",
    # By definition copy and cast will depend on inputs for output.
    "copy",
    "cast",
    "cast_like",
    # Simple arithmetic
    "add",
    "subtract",
    "multiply",
    "divide",
    "nn.bias_add",
    "nn.batch_norm",
    "sum",
    "mean",
    "sqrt",
    "shape_of",
    # Simple activations
    "max",
    "min",
    "maximum",
    "minimum",
    "nn.relu",
    "nn.leaky_relu",
    "nn.prelu",
    "nn.dropout",
    # Complicated activations which saturate in a narrow range
    "sigmoid",
    "tanh",
    # Pooling operations
    "nn.max_pool1d",
    "nn.max_pool2d",
    "nn.max_pool3d",
    "nn.avg_pool1d",
    "nn.avg_pool2d",
    "nn.avg_pool3d",
    # "nn.global_max_pool1d", # does not exist yet
    "nn.global_max_pool2d",
    # "nn.global_max_pool3d", # does not exist yet
    # "nn.global_avg_pool1d", # does not exist yet
    "nn.global_avg_pool2d",
    # "nn.global_avg_pool3d", # does not exist yet
    "nn.adaptive_max_pool1d",
    "nn.adaptive_max_pool2d",
    "nn.adaptive_max_pool3d",
    "nn.adaptive_avg_pool1d",
    "nn.adaptive_avg_pool2d",
    "nn.adaptive_avg_pool3d",
]
DEFAULT_NEVER_LIST = [
    # In general if |f(x)| >> |x| for expected inputs then put the op here.
    "exp",
    "power",
    "nn.cross_entropy",
    "nn.cross_entropy_with_logits",
    "nn.softmax",
    "nn.l2_normalize",
    # Error function doesn't seem to be able to be lowered into fp16 version in llvm.
    # Move to follow list when it does.
    "erf",
]


# Returns a decorator which registers for every given op, the function under
# FTVMMixedPrecisionConversionType
def register_func_to_op_list(list_ops: List):
    def decorator(func):
        for op_name in list_ops:
            register_mixed_precision_conversion(op_name, func=func)

    return decorator


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
    """A function which returns output dtypes in a way which works for most ops.

    Parameters
    ----------
    call_node: relay.Call
        The call node containing the op.
    mixed_precision_type: str
        The target type to run the operation in.

    Returns
    -------
    output_dtypes : [str, str]
        A list of two strings. The first represents the datatype used for accumulation
        in the operation. The second represents the actual output datatype.
    """
    # Assume support accumulation dtypes <---> has out_dtype attr.
    # This is because there is no better way right now to tell which ops support accumulating
    # at different data types.
    # Some discussion here about making this better is here:
    # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
    if hasattr(call_node.attrs, "out_dtype"):
        return ["float32", mixed_precision_type]

    # [accumulation_dtype, output_dtype] for the operations
    return [mixed_precision_type, mixed_precision_type]


# Functions for FTVMMixedPrecisionConversionType which
# take in CallNodes and a DType and return a conversion type,
# an accumulation dtype, and an output_dtype.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_mixed_precision_conversion("nn.batch_matmul")
def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
    # Batched matmul has inconsistent support for mixed precision operations.
    # Many schedules ignore the out_dtype attribute which leads to errors when
    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
```
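The registration hook used above can also override the defaults for a single op. Below is a hedged sketch (not part of this diff) that re-registers "nn.softmax" as an ALWAYS op purely for illustration; the level-based override and the numerical safety of running softmax in reduced precision are assumptions to verify against your TVM version and model.

```python
# Hedged sketch: overriding the default conversion rule for one op.
# Assumes register_mixed_precision_conversion accepts a `level` argument so that a
# higher-level registration takes precedence over the default (level 10) one.
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

MIXED_PRECISION_ALWAYS = 0  # same integer code as in the module above


@register_mixed_precision_conversion("nn.softmax", level=11)
def softmax_always(call_node: relay.Call, mixed_precision_type: str) -> List:
    # Return value layout matches the generic helpers above:
    # [conversion category, accumulation dtype, output dtype].
    # Softmax has no out_dtype attribute, so accumulate in the mixed precision type.
    return [MIXED_PRECISION_ALWAYS, mixed_precision_type, mixed_precision_type]
```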