15 changes: 4 additions & 11 deletions python/tvm/relay/transform/mixed_precision.py
@@ -40,7 +40,7 @@
     "nn.conv2d_transpose",
     "nn.conv3d_transpose",
     "nn.dense",
-    # "nn.batch_matmul", # Handled by a special case
+    "nn.batch_matmul",
 ]
 DEFAULT_FOLLOW_LIST = [
     # These ops add new data or change shape
@@ -162,7 +162,9 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List:
     # Some discussion here about making this better is here:
     # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
     if hasattr(call_node.attrs, "out_dtype"):
-        return ["float32", mixed_precision_type]
+        # TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
+        # return ["float32", mixed_precision_type]
+        return [mixed_precision_type, mixed_precision_type]

     # [accumulation_dtype, output_dtype] for the operations
     return [mixed_precision_type, mixed_precision_type]
@@ -184,12 +186,3 @@ def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
 @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
 def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
     return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
-
-
-@register_mixed_precision_conversion("nn.batch_matmul")
-def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
-    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
-    # Batched matmul has inconsistent support for mixed precision operations.
-    # Many schedules ignore the out_dtype attribute which leads to errors when
-    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
-    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
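
With the special case deleted, `nn.batch_matmul` now takes the generic path and accumulates in the mixed-precision dtype. A downstream user who still wants FP32 accumulation for this op could re-register a conversion rule. The following is a minimal sketch, not part of this PR: it assumes `register_mixed_precision_conversion` is importable from `tvm.relay.op` (as this file imports it) and that a higher registration level overrides the default one.

```python
# Hypothetical user-side override: restore FP32 accumulation for
# nn.batch_matmul now that the built-in special case is gone.
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS


# level=11 is assumed to take precedence over the default registration.
@register_mixed_precision_conversion("nn.batch_matmul", level=11)
def batch_matmul_fp32_acc(call_node: relay.Call, mixed_precision_type: str) -> List:
    # Return format: [conversion category, accumulation dtype, output dtype].
    return [MIXED_PRECISION_ALWAYS, "float32", mixed_precision_type]
```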
9 changes: 3 additions & 6 deletions tests/python/relay/test_op_level10.py
@@ -18,14 +18,11 @@
 """
 import numpy as np
 import tvm
-from tvm import te
+import tvm.testing
 import tvm.topi.testing
-from tvm import relay
+from tvm import relay, te, topi
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-from tvm import topi
-import tvm.topi.testing
-import tvm.testing


 @tvm.testing.uses_gpu
@@ -608,7 +605,7 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, device=dev, target=target)
            out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
-           tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+           tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-6, atol=1e-6)

    _verify((10, 5))
    _verify((10, 5, 2, 2))
84 changes: 39 additions & 45 deletions tests/python/relay/test_to_mixed_precision.py
@@ -48,6 +48,7 @@ def verify_mixed_precision_output_close(
     result_fp32 = run_module(mod, mod_params)
     fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
     result_fp16 = run_module(fp16_mod, mod_params)
+
     # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
@@ -60,7 +61,9 @@ def test_lstm():

     Has internal functions and let statements the pass must work on.
     """
-    units = 3
+    # TODO(AndrewZhaoLuo): investigate why non-even units cause failure in codegen for CUDA
+    # See discussion here: https://github.com/apache/tvm/issues/8294#issuecomment-866190408
+    units = 4
     iterations = 5
     mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)

@@ -118,16 +121,13 @@ def test_convert_single_conv():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float16"),
-                relay.cast(weight, "float16"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float16",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)

@@ -156,16 +156,13 @@ def test_convert_single_conv_fp64():
     # Note we still accumulate to FP32 by default, a user would need to overwrite default
     # behavior to make this make more sense.
     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float64"),
-                relay.cast(weight, "float64"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float64",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float64"),
+            relay.cast(weight, "float64"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float64",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)

@@ -198,15 +195,12 @@ def test_convert_conv_bn():
         "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
         "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.025, rtol=0.01)

     # Creating expected module
     data = relay.cast(relay.var("data", shape=data_shape), "float16")
     weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
-    conv = relay.cast(
-        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
-        "float16",
-    )
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float16")

     bn_shape = [5]
     gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
@@ -254,17 +248,14 @@ def test_green_gray_propagates_simple():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)

-    conv_expr = relay.cast(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float32",
-        ),
-        "float16",
+    conv_expr = relay.nn.conv2d(
+        relay.cast(data, "float16"),
+        relay.cast(weight, "float16"),
+        strides=(1, 1),
+        padding=(1, 1),
+        out_dtype="float16",
     )
     expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -316,12 +307,15 @@ def test_green_red_not_use_extraneous_cast():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

     # Construct expected structure
-    conv = relay.nn.conv2d(
-        relay.cast(data, "float16"),
-        relay.cast(weight, "float16"),
-        strides=(1, 1),
-        padding=(1, 1),
-        out_dtype="float32",
+    conv = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
+        "float32",
     )
     result = relay.nn.softmax(conv)
     expected_mod = tvm.IRModule.from_expr(result)
@@ -380,12 +374,12 @@ def test_let_statement_simple():
     r2 = var2 + var2
     let2 = relay.Let(
         var2,
-        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(r1, weight, units=20, out_dtype="float16"),
         r2,
     )
     let1 = relay.Let(
         var1,
-        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(data, weight, units=20, out_dtype="float16"),
         let2,
     )
     expected_mod = tvm.IRModule.from_expr(let1)
@@ -410,7 +404,7 @@ def test_where_simple():
     # Create expected module
     data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
     weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
-    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    a = relay.nn.dense(data, weight, units=20, out_dtype="float16")
     b = relay.where(data, a, a)
     expected_mod = tvm.IRModule.from_expr(b)
     expected_mod = InferType()(expected_mod)
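
Taken together, the updated tests pin down the pass's new contract: a converted op carries `out_dtype` equal to the mixed-precision type, and no cast back from an FP32 accumulator is inserted after it. A minimal sketch of exercising the pass directly, using the same `ToMixedPrecision` and `InferType` entry points as the test file; the single-dense module below is illustrative, not taken from the test suite:

```python
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

# Illustrative single-dense module, all in float32.
data = relay.var("data", shape=[1, 20])
weight = relay.var("weight", shape=[20, 20])
mod = InferType()(tvm.IRModule.from_expr(relay.nn.dense(data, weight, units=20)))

# After the pass, nn.dense should carry out_dtype="float16" directly,
# with casts only on its inputs and no trailing cast from float32.
fp16_mod = ToMixedPrecision("float16")(mod)
print(fp16_mod)
```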