From 3d08ee1513da44051d258a292334064683d13579 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 14 Sep 2022 16:59:24 +0800 Subject: [PATCH 1/3] refine AMP for bfloat16 --- src/relay/transforms/to_mixed_precision.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index e1d3a264c222..18161b3c2508 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -161,7 +161,9 @@ class MixedPrecisionPass : public MixedModeMutator { */ DataType cur_type = (attrs->out_dtype); ObjectPtr new_attrs = make_object(*attrs); - if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype; + if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) { + new_attrs->out_dtype = accumulation_dtype; + } return Attrs(new_attrs); } @@ -175,7 +177,9 @@ class MixedPrecisionPass : public MixedModeMutator { */ DataType cur_type = (attrs->dtype); ObjectPtr new_attrs = make_object(*attrs); - if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype; + if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) { + new_attrs->dtype = accumulation_dtype; + } return Attrs(new_attrs); } @@ -217,7 +221,7 @@ class MixedPrecisionPass : public MixedModeMutator { /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */ // If this is not a floating point type, do not cast. E.g. it might be an integer - if (!expr_dtype.is_float()) { + if (!(expr_dtype.is_float() || expr_dtype.is_bfloat16())) { return expr; } @@ -299,7 +303,7 @@ class MixedPrecisionPass : public MixedModeMutator { original_dtype_.push_back((root_->checked_type_).as()->dtype); } } - if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) { + if (!(mixed_precision_type_.is_float() || mixed_precision_type_.is_bfloat16())) { LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " << mixed_precision_type_; } From d9005be47e98eabbb164bb7036508cda9b99e29c Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 14 Sep 2022 16:59:50 +0800 Subject: [PATCH 2/3] refine AMP tests to cover bfloat16 --- tests/python/relay/test_to_mixed_precision.py | 178 ++++++++++-------- 1 file changed, 104 insertions(+), 74 deletions(-) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 026b458bde12..51d040c311f4 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -24,6 +24,12 @@ from tvm.relay.testing import lstm from tvm.relay.transform import InferType, ToMixedPrecision, mixed_precision +target_precision = tvm.testing.parameter( + pytest.param("float16"), + pytest.param("bfloat16"), + ids=["float16", "bfloat16"], +) + def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: dev = tvm.device("llvm", 0) @@ -48,28 +54,29 @@ def verify_mixed_precision_output_close( result_fp32 = run_module(mod, mod_params) if not keep_orig_output_dtype: - fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod) - result_fp16 = run_module(fp16_mod, mod_params) + amp_mod = ToMixedPrecision(mixed_precision_dtype)(mod) + result_amp = run_module(amp_mod, mod_params) else: with tvm.transform.PassContext( config={"relay.ToMixedPrecision.keep_orig_output_dtype": True} ): - fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod) - 
result_fp16 = run_module(fp16_mod, mod_params) + amp_mod = ToMixedPrecision(mixed_precision_dtype)(mod) + result_amp = run_module(amp_mod, mod_params) # Ensure the results are close - for fp32, fp16 in zip(result_fp32, result_fp16): - np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol) + if mixed_precision_dtype != "bfloat16": + for fp32, amp in zip(result_fp32, result_amp): + np.testing.assert_allclose(fp32, amp, rtol=rtol, atol=atol) if keep_orig_output_dtype: assert ( - np.array(result_fp16).dtype == np.array(result_fp32).dtype + np.array(result_amp).dtype == np.array(result_fp32).dtype ), "output type and original type mismatch" - return fp16_mod + return amp_mod -def test_lstm(): +def test_lstm(target_precision): """A small stress test on a single unrolled lstm unit. Has internal functions and let statements the pass must work on. @@ -87,7 +94,9 @@ def test_lstm(): -10, 10, (1, units) ).astype("float32") - verify_mixed_precision_output_close(mod, mod_params, rtol=0.01, atol=0.01) + verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, rtol=0.01, atol=0.01 + ) def test_lstm_float64(): @@ -114,7 +123,7 @@ def test_lstm_float64(): ) -def test_convert_single_conv(): +def test_convert_single_conv(target_precision): """Conv is a green listed operation meaning it will always use fp16 workload. By default it accumulates to fp32 and outputs fp16. @@ -131,26 +140,31 @@ def test_convert_single_conv(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_mixed_precision_output_close( - mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True + amp_mod = verify_mixed_precision_output_close( + mod, + mod_params, + mixed_precision_dtype=target_precision, + atol=0.01, + rtol=1e-3, + keep_orig_output_dtype=True, ) expected_mod = tvm.IRModule.from_expr( relay.cast( relay.nn.conv2d( - relay.cast(data, "float16"), - relay.cast(weight, "float16"), + relay.cast(data, target_precision), + relay.cast(weight, target_precision), strides=(1, 1), padding=(1, 1), - out_dtype="float16", + out_dtype=target_precision, ), "float32", ) ) expected_mod = tvm.relay.transform.InferType()(expected_mod) - assert not tvm.ir.structural_equal(fp16_mod, mod) - assert tvm.ir.structural_equal(fp16_mod, expected_mod) + assert not tvm.ir.structural_equal(amp_mod, mod) + assert tvm.ir.structural_equal(amp_mod, expected_mod) def test_convert_single_conv_fp64(): @@ -167,7 +181,7 @@ def test_convert_single_conv_fp64(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_mixed_precision_output_close( + amp_mod = verify_mixed_precision_output_close( mod, mod_params, mixed_precision_dtype="float64", atol=0.01, rtol=1e-3 ) @@ -184,11 +198,11 @@ def test_convert_single_conv_fp64(): ) expected_mod = tvm.relay.transform.InferType()(expected_mod) - assert not tvm.ir.structural_equal(fp16_mod, mod) - assert tvm.ir.structural_equal(fp16_mod, expected_mod) + assert not tvm.ir.structural_equal(amp_mod, mod) + assert tvm.ir.structural_equal(amp_mod, expected_mod) -def test_convert_conv_bn(): +def test_convert_conv_bn(target_precision): """Conv is green and batch norm is gray. 
As Conv should output fp16 batch_norm should be green.""" data_shape = (1, 3, 32, 32) weight_shape = (5, 3, 3, 3) @@ -213,49 +227,51 @@ def test_convert_conv_bn(): "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), } - fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.025, rtol=0.01) + amp_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.025, rtol=0.01 + ) # Creating expected module - data = relay.cast(relay.var("data", shape=data_shape), "float16") - weight = relay.cast(relay.var("weight", shape=weight_shape), "float16") - conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float16") + data = relay.cast(relay.var("data", shape=data_shape), target_precision) + weight = relay.cast(relay.var("weight", shape=weight_shape), target_precision) + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype=target_precision) bn_shape = [5] - gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16") - beta = relay.cast(relay.var("beta", shape=bn_shape), "float16") - moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16") - moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16") + gamma = relay.cast(relay.var("gamma", shape=bn_shape), target_precision) + beta = relay.cast(relay.var("beta", shape=bn_shape), target_precision) + moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), target_precision) + moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), target_precision) bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var) expected_mod = tvm.IRModule.from_expr(bn[0]) expected_mod = tvm.relay.transform.InferType()(expected_mod) - assert not tvm.ir.structural_equal(fp16_mod, mod) - assert tvm.ir.structural_equal(fp16_mod, expected_mod) + assert not tvm.ir.structural_equal(amp_mod, mod) + assert tvm.ir.structural_equal(amp_mod, expected_mod) -def test_do_not_convert_softmax(): +def test_do_not_convert_softmax(target_precision): """Softmax is a red listed operation and therefore should never be fp16.""" shape = [1, 2, 3] a = relay.var("a", shape=shape) b = relay.nn.softmax(a) mod = tvm.IRModule.from_expr(b) mod = tvm.relay.transform.InferType()(mod) - out_mod = ToMixedPrecision("float16")(mod) + out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) assert tvm.ir.structural_equal(orig_mod, out_mod) -def test_do_not_convert_arange(): +def test_do_not_convert_arange(target_precision): """Arange is a red listed operation and therefore should never be fp16.""" dtype = "float32" arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) mod = tvm.IRModule.from_expr(arange) - out_mod = ToMixedPrecision("float16")(mod) + out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) assert tvm.ir.structural_equal(orig_mod, out_mod) -def test_do_not_convert_summation(): +def test_do_not_convert_summation(target_precision): """Ops that could involve a large summation are not allowed in fp16.""" shape = [1, 3, 16, 16] a = relay.var("a", shape=shape) @@ -267,12 +283,12 @@ def test_do_not_convert_summation(): ] for op in ops: mod = tvm.IRModule.from_expr(op(a)) - out_mod = ToMixedPrecision("float16")(mod) + out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) assert 
tvm.ir.structural_equal(orig_mod, out_mod) -def test_green_gray_propagates_simple(): +def test_green_gray_propagates_simple(target_precision): """Conv is a green listed operation, while addition is gray. As Conv outputs fp16 the add should be done in fp16. @@ -290,23 +306,25 @@ def test_green_gray_propagates_simple(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + amp_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) conv_expr = relay.nn.conv2d( - relay.cast(data, "float16"), - relay.cast(weight, "float16"), + relay.cast(data, target_precision), + relay.cast(weight, target_precision), strides=(1, 1), padding=(1, 1), - out_dtype="float16", + out_dtype=target_precision, ) expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr) expected_mod = tvm.relay.transform.InferType()(expected_mod) - assert not tvm.ir.structural_equal(fp16_mod, mod) - assert tvm.ir.structural_equal(fp16_mod, expected_mod) + assert not tvm.ir.structural_equal(amp_mod, mod) + assert tvm.ir.structural_equal(amp_mod, expected_mod) -def test_green_red_not_use_extraneous_cast(): +def test_green_red_not_use_extraneous_cast(target_precision): """Conv. is a green listed operation, while softmax is red. Conv. also by default accumulates to fp32 but outputs fp16. @@ -346,16 +364,18 @@ def test_green_red_not_use_extraneous_cast(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + amp_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=1e-3 + ) # Construct expected structure conv = relay.cast( relay.nn.conv2d( - relay.cast(data, "float16"), - relay.cast(weight, "float16"), + relay.cast(data, target_precision), + relay.cast(weight, target_precision), strides=(1, 1), padding=(1, 1), - out_dtype="float16", + out_dtype=target_precision, ), "float32", ) @@ -363,10 +383,10 @@ def test_green_red_not_use_extraneous_cast(): expected_mod = tvm.IRModule.from_expr(result) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, fp16_mod) + assert tvm.ir.structural_equal(expected_mod, amp_mod) -def test_red_gray_propagates_simple(): +def test_red_gray_propagates_simple(target_precision): """Everything after a softmax should be in FP32 (exception green colored ops)""" shape = [1, 2, 3] a = relay.var("a", shape=shape) @@ -378,12 +398,14 @@ def test_red_gray_propagates_simple(): mod_params = { "a": np.random.uniform(-1, 1, size=shape).astype("float32"), } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0.0) + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.0, rtol=0.0 + ) assert tvm.ir.structural_equal(mod, output_mod) -def test_let_statement_simple(): +def test_let_statement_simple(target_precision): """A 'simple' let statement example. Noticeable is the mutation of the bound variable types. 
@@ -405,23 +427,25 @@ def test_let_statement_simple(): "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.05, rtol=0.15) + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.05, rtol=0.15 + ) # Construct expected structure - var1 = relay.var("var1", shape=[1, 20], dtype="float16") - var2 = relay.var("var2", shape=[1, 20], dtype="float16") - data = relay.cast(relay.var("data", shape=[1, 20]), "float16") - weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") + var1 = relay.var("var1", shape=[1, 20], dtype=target_precision) + var2 = relay.var("var2", shape=[1, 20], dtype=target_precision) + data = relay.cast(relay.var("data", shape=[1, 20]), target_precision) + weight = relay.cast(relay.var("weight", shape=[20, 20]), target_precision) r1 = var1 + var1 r2 = var2 + var2 let2 = relay.Let( var2, - relay.nn.dense(r1, weight, units=20, out_dtype="float16"), + relay.nn.dense(r1, weight, units=20, out_dtype=target_precision), r2, ) let1 = relay.Let( var1, - relay.nn.dense(data, weight, units=20, out_dtype="float16"), + relay.nn.dense(data, weight, units=20, out_dtype=target_precision), let2, ) expected_mod = tvm.IRModule.from_expr(let1) @@ -430,7 +454,7 @@ def test_let_statement_simple(): assert tvm.ir.structural_equal(expected_mod, output_mod) -def test_where_simple(): +def test_where_simple(target_precision): data = relay.var("data", shape=[1, 20]) weight = relay.var("weight", shape=[20, 20]) a = relay.nn.dense(data, weight, units=20) @@ -441,12 +465,14 @@ def test_where_simple(): "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) # Create expected module - data = relay.cast(relay.var("data", shape=[1, 20]), "float16") - weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") - a = relay.nn.dense(data, weight, units=20, out_dtype="float16") + data = relay.cast(relay.var("data", shape=[1, 20]), target_precision) + weight = relay.cast(relay.var("weight", shape=[20, 20]), target_precision) + a = relay.nn.dense(data, weight, units=20, out_dtype=target_precision) b = relay.where(data, a, a) expected_mod = tvm.IRModule.from_expr(b) expected_mod = InferType()(expected_mod) @@ -454,7 +480,7 @@ def test_where_simple(): assert tvm.ir.structural_equal(expected_mod, output_mod) -def test_batch_matmul_simple(): +def test_batch_matmul_simple(target_precision): """Batch matmul is a special case where we try to accumulate to fp16. 
    This is due to the fact heterogenous accumulation dtypes does not work
@@ -468,17 +494,19 @@
         "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
     }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
+    )
     # Create expected module
-    data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
-    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
-    a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
+    data = relay.cast(relay.var("data", shape=[1, 1, 20]), target_precision)
+    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), target_precision)
+    a = relay.nn.batch_matmul(data, weight, out_dtype=target_precision)
     expected_mod = tvm.IRModule.from_expr(a)
     expected_mod = InferType()(expected_mod)
     assert tvm.ir.structural_equal(expected_mod, output_mod)

-def test_convert_follow_node_with_integer_arguments():
+def test_convert_follow_node_with_integer_arguments(target_precision):
     """Tests the conversion of a follow op with integer arguments + constant float args.
     The follow op should convert the floating point argument into fp16 as constants/vars
@@ -497,10 +525,12 @@
         "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
         "indices": np.array([[0]]).astype("int32"),
     }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
+    )
     # Create expected module
-    data = relay.cast(relay.var("data", shape=[1, 10]), "float16")
+    data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
     take = relay.take(data, indices, axis=0)
     expected_mod = tvm.IRModule.from_expr(take)
     expected_mod = InferType()(expected_mod)

From 9ea7b7e32b1e3d7abe8daae8524e81afa886a2b2 Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Fri, 16 Sep 2022 10:07:44 +0800
Subject: [PATCH 3/3] refine accuracy checking for dnnl bf16

---
 tests/python/contrib/test_dnnl.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index c4adc9785c19..f23b3c70aa96 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -150,9 +150,8 @@ def assert_result_dict_holds(result_dict):
         res1 = vmobj_to_list(result_dict[k1])
         res2 = vmobj_to_list(result_dict[k2])
         for r1, r2 in zip(res1, res2):
-            if "bf16" in k1 or "bf16" in k2:
-                np.testing.assert_array_almost_equal(r1, r2, decimal=1)
-            else:
+            # skip the accuracy check when only one of the two results is bf16
+            if ("bf16" in k1) == ("bf16" in k2):
                 tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
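
Note (illustrative only, not part of the patches): a minimal sketch of how the
updated pass is expected to be driven for bfloat16, mirroring the imports and
calls used in the parametrized tests above. The toy conv2d module and its
shapes below are made up for the example; "float16" would work the same way.

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    # Build a small float32 module; conv2d is a green-listed op, so the pass
    # rewrites it to run in the requested mixed precision dtype.
    data = relay.var("data", shape=(1, 3, 32, 32))
    weight = relay.var("weight", shape=(8, 3, 3, 3))
    conv = relay.nn.conv2d(data, weight, padding=(1, 1))
    mod = InferType()(tvm.IRModule.from_expr(conv))

    # With patch 1 applied, "bfloat16" is accepted here just like "float16".
    bf16_mod = ToMixedPrecision("bfloat16")(mod)
    print(bf16_mod)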