Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
from . import fake_quantization_to_integer
from . import fake_quantization_to_integer, mixed_precision
9 changes: 4 additions & 5 deletions python/tvm/relay/transform/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
Expand Down Expand Up @@ -141,7 +140,7 @@ def decorator(func):
return decorator


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]:
"""A function which returns output dtypes in a way which works for most ops.

Parameters
Expand Down Expand Up @@ -174,15 +173,15 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) ->
# Take in CallNodes and a DType and returns a conversion type,
# an accumulation dtype, and an output_dtype.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)