From a541908839011ecf2c1e6bb0d70a9e184a7093cd Mon Sep 17 00:00:00 2001
From: KJlaccHoeUM9l
Date: Tue, 18 Oct 2022 16:12:15 +0300
Subject: [PATCH 1/4] add converter for FastGelu from Microsoft onnxruntime
 contrib opset

---
 python/tvm/relay/frontend/onnx.py | 32 ++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 84a5fc3b8237..73e9e7cc78a0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -926,10 +926,39 @@ def _impl_v1(cls, inputs, attr, params):
         return _op.multiply(term1, term2)
 
 
+class FastGelu(OnnxOpConverter):
+    """Operator converter for FastGelu from Microsoft onnxruntime contrib opset.
+
+    fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
+                 = 0.5x(1 + tanh(sqrt(2/pi)x + 0.044715 sqrt(2/pi) x^3))
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        x = inputs[0]
+        if inputs[1]:
+            bias = inputs[1]
+            x += bias
+
+        # Declare consts
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        const1 = _expr.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+        const2 = _expr.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
+
+        # Compute FastGelu
+        term1 = _op.multiply(half, x)
+        term2 = _op.multiply(const1, x)
+        term3 = _op.multiply(const2, _op.power(x, _expr.const(3, const_dtype)))
+        tanh = _op.tanh(_op.add(term2, term3))
+        return _op.multiply(term1, _op.add(one, tanh))
+
+
 class BiasGelu(OnnxOpConverter):
     """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
 
-    bias_gelu(x, b) = 0.5(x, b)(1 + erf((x + b)/sqrt(2)))
+    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
     """
 
     @classmethod
@@ -5335,6 +5364,7 @@ def _get_convert_map(opset):
         "Selu": Selu.get_converter(opset),
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
+        "FastGelu": FastGelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
         # TODO: We need a better way to handle different domains, in case
         # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
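The tanh approximation implemented above can be sanity-checked outside TVM with a few lines of NumPy. This is an illustrative sketch, not part of the patch: fast_gelu_reference is a hypothetical name, and its optional bias argument mirrors the converter's handling of inputs[1].

    import math

    import numpy as np


    def fast_gelu_reference(x, bias=None):
        # Mirrors the converter: an optional bias is added before the activation.
        if bias is not None:
            x = x + bias
        c1 = math.sqrt(2 / math.pi)  # const1 in the converter
        c2 = 0.044715 * math.sqrt(2 / math.pi)  # const2 in the converter
        return 0.5 * x * (1 + np.tanh(c1 * x + c2 * x**3))


    x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    print(fast_gelu_reference(x))  # approximately [-0.1588, 0.0, 0.8412]

Unlike Gelu, which evaluates erf exactly, FastGelu trades a small approximation error for the cheaper tanh-based formula, which is why it can share tests with Gelu and BiasGelu only under a relaxed tolerance (see PATCH 3/4 below).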
From 1feebd4e3216e1860349734395a374c9f03914de Mon Sep 17 00:00:00 2001
From: KJlaccHoeUM9l
Date: Tue, 18 Oct 2022 16:13:54 +0300
Subject: [PATCH 2/4] integrate FastGelu into test system for ONNX converters

---
 tests/python/frontend/onnx/test_forward.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index da6f5785023d..352be48d8213 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5638,13 +5638,14 @@ def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
     verify_reverse_sequence(x, sequence_lens, 1, 0)
 
 
+@pytest.mark.parametrize("op_name", ["Gelu", "FastGelu"], scope="session")
 @tvm.testing.parametrize_targets
-def test_gelu(target, dev):
+def test_gelu(target, dev, op_name):
     """test_gelu"""
 
     def verify_gelu(x):
         node = onnx.helper.make_node(
-            "Gelu",
+            op_name,
            inputs=["x"],
             outputs=["y"],
             domain="com.microsoft",
@@ -5666,13 +5667,14 @@ def verify_gelu(x):
     verify_gelu(x)
 
 
+@pytest.mark.parametrize("op_name", ["BiasGelu", "FastGelu"], scope="session")
 @tvm.testing.parametrize_targets
-def test_biasgelu(target, dev):
+def test_biasgelu(target, dev, op_name):
     """test_biasgelu"""
 
     def verify_biasgelu(x, bias):
         node = onnx.helper.make_node(
-            "BiasGelu",
+            op_name,
             inputs=["x", "bias"],
             outputs=["y"],
             domain="com.microsoft",
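The parametrization above simply substitutes op_name into onnx.helper.make_node, so one test body exercises either converter. For readers without the test scaffolding, here is a standalone sketch of what the test builds for op_name="FastGelu" (an assumption-laden example: it requires the converter from PATCH 1/4 to be applied, and it replaces the verify_with_ort_with_inputs helper, which is TVM test infrastructure, with a plain Relay import):

    import numpy as np
    from onnx import TensorProto, helper

    from tvm import relay

    x = np.array([[1, 2], [3, 4]], dtype=np.float32)

    # FastGelu is not a standard ONNX op: it lives in the com.microsoft
    # contrib domain, hence domain="com.microsoft" on the node.
    node = helper.make_node("FastGelu", inputs=["x"], outputs=["y"], domain="com.microsoft")
    graph = helper.make_graph(
        [node],
        "fastgelu_example",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
    )
    model = helper.make_model(graph, producer_name="fastgelu_example")

    # The converter registered in _get_convert_map turns the contrib node into Relay ops.
    mod, params = relay.frontend.from_onnx(model, shape={"x": x.shape})
    print(mod)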
From c37bc161ca13036698068976375cf3f186efc2d4 Mon Sep 17 00:00:00 2001
From: KJlaccHoeUM9l
Date: Fri, 21 Oct 2022 16:16:07 +0300
Subject: [PATCH 3/4] code review fixes

---
 python/tvm/relay/frontend/onnx.py          | 10 +++-
 tests/python/frontend/onnx/test_forward.py | 56 ++++++++++++++--------
 2 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 73e9e7cc78a0..adba6f8bcdd2 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -931,6 +931,10 @@ class FastGelu(OnnxOpConverter):
 
     fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
                  = 0.5x(1 + tanh(sqrt(2/pi)x + 0.044715 sqrt(2/pi) x^3))
+                 = 0.5x(1 + tanh(c1 * x + c2 * x^3))
+    where
+        c1 = sqrt(2/pi)
+        c2 = 0.044715 * sqrt(2/pi)
     """
 
     @classmethod
@@ -938,14 +942,16 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]
         if inputs[1]:
             bias = inputs[1]
+            bias_shape = infer_shape(bias)
+            assert len(bias_shape) == 1, "bias term must be a 1D tensor"
             x += bias
 
         # Declare consts
         const_dtype = infer_type(x).checked_type.dtype
         half = _expr.const(0.5, dtype=const_dtype)
         one = _expr.const(1.0, dtype=const_dtype)
-        const1 = _expr.const(math.sqrt(2 / math.pi), dtype=const_dtype)
-        const2 = _expr.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
+        const1 = _expr.const(0.7978845608028654, dtype=const_dtype)  # sqrt(2.0 / PI)
+        const2 = _expr.const(0.0356774081363001, dtype=const_dtype)  # 0.044715 * sqrt(2.0 / PI)
 
         # Compute FastGelu
         term1 = _op.multiply(half, x)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 352be48d8213..dc82f3060360 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5639,9 +5639,13 @@ def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
 
 
 @pytest.mark.parametrize("op_name", ["Gelu", "FastGelu"], scope="session")
+@pytest.mark.parametrize("data_type", ["float16", "float32"], scope="session")
 @tvm.testing.parametrize_targets
-def test_gelu(target, dev, op_name):
+def test_gelu(target, dev, data_type, op_name):
     """test_gelu"""
+    dtype = np.dtype(data_type)
+    tensor_type = mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+    absolute_tolerance = 1e-3 if data_type == "float16" else 1e-5
 
     def verify_gelu(x):
         node = onnx.helper.make_node(
@@ -5653,24 +5657,30 @@ def verify_gelu(x):
 
         graph = helper.make_graph(
             [node],
-            "gelu_test",
-            inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))],
-            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
+            f"{op_name}_test",
+            inputs=[helper.make_tensor_value_info("x", tensor_type, list(x.shape))],
+            outputs=[helper.make_tensor_value_info("y", tensor_type, list(x.shape))],
         )
 
-        model = helper.make_model(graph, producer_name="gelu_test")
-        verify_with_ort_with_inputs(model, [x], [x.shape], target=target, dev=dev)
+        model = helper.make_model(graph, producer_name=f"{op_name}_test")
+        verify_with_ort_with_inputs(
+            model, [x], [x.shape], atol=absolute_tolerance, dtype=data_type, target=target, dev=dev
+        )
 
-    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=np.float32)
+    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=dtype)
     verify_gelu(x)
 
-    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
+    x = np.array([[1, 2], [3, 4]], dtype=dtype)
     verify_gelu(x)
 
 
 @pytest.mark.parametrize("op_name", ["BiasGelu", "FastGelu"], scope="session")
+@pytest.mark.parametrize("data_type", ["float16", "float32"], scope="session")
 @tvm.testing.parametrize_targets
-def test_biasgelu(target, dev, op_name):
+def test_biasgelu(target, dev, data_type, op_name):
     """test_biasgelu"""
+    dtype = np.dtype(data_type)
+    tensor_type = mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+    absolute_tolerance = 1e-3 if data_type == "float16" else 1e-5
 
     def verify_biasgelu(x, bias):
         node = onnx.helper.make_node(
@@ -5682,23 +5692,31 @@ def verify_biasgelu(x, bias):
 
         graph = helper.make_graph(
             [node],
-            "biasgelu_test",
+            f"{op_name}_test",
             inputs=[
-                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)),
-                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+                helper.make_tensor_value_info("x", tensor_type, list(x.shape)),
+                helper.make_tensor_value_info("bias", tensor_type, list(bias.shape)),
             ],
-            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
+            outputs=[helper.make_tensor_value_info("y", tensor_type, list(x.shape))],
         )
 
-        model = helper.make_model(graph, producer_name="biasgelu_test")
-        verify_with_ort_with_inputs(model, [x, bias], [x.shape], target=target, dev=dev)
+        model = helper.make_model(graph, producer_name=f"{op_name}_test")
+        verify_with_ort_with_inputs(
+            model,
+            [x, bias],
+            [x.shape],
+            atol=absolute_tolerance,
+            dtype=data_type,
+            target=target,
+            dev=dev,
+        )
 
-    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=np.float32)
-    bias = np.repeat(2.0, 7).astype("float32")
+    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=dtype)
+    bias = np.repeat(2.0, 7).astype(dtype)
     verify_biasgelu(x, bias)
 
-    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
-    bias = np.array([0.3, 4.0], dtype=np.float32)
+    x = np.array([[1, 2], [3, 4]], dtype=dtype)
+    bias = np.array([0.3, 4.0], dtype=dtype)
    verify_biasgelu(x, bias)
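Two details of this review round deserve a note. mapping.NP_TYPE_TO_TENSOR_TYPE resolves a NumPy dtype to the matching TensorProto element type, and the looser float16 tolerance follows directly from half-precision resolution. A quick check, illustrative only and not part of the test suite:

    import numpy as np
    from onnx import TensorProto, mapping

    # tensor_type in the test: NumPy dtype -> ONNX TensorProto element type.
    assert mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype("float16")] == TensorProto.FLOAT16
    assert mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype("float32")] == TensorProto.FLOAT

    # Why atol is 1e-3 for float16: one ulp at 1.0 is already ~1e-3 there,
    # so the float32 tolerance of 1e-5 could never be met in half precision.
    print(np.finfo(np.float16).eps)  # ~0.000977
    print(np.finfo(np.float32).eps)  # ~1.19e-07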
From 8f9ba0ba478a059ca928fa8d90a7315a1a290cfb Mon Sep 17 00:00:00 2001
From: KJlaccHoeUM9l
Date: Fri, 21 Oct 2022 16:45:29 +0300
Subject: [PATCH 4/4] return to computing the constants

---
 python/tvm/relay/frontend/onnx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index adba6f8bcdd2..34794b3fdc1d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -950,8 +950,8 @@ def _impl_v1(cls, inputs, attr, params):
         const_dtype = infer_type(x).checked_type.dtype
         half = _expr.const(0.5, dtype=const_dtype)
         one = _expr.const(1.0, dtype=const_dtype)
-        const1 = _expr.const(0.7978845608028654, dtype=const_dtype)  # sqrt(2.0 / PI)
-        const2 = _expr.const(0.0356774081363001, dtype=const_dtype)  # 0.044715 * sqrt(2.0 / PI)
+        const1 = _expr.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+        const2 = _expr.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
 
         # Compute FastGelu
         term1 = _op.multiply(half, x)
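The revert is purely cosmetic: in float64 the computed expressions reproduce the literals that PATCH 3/4 briefly hardcoded. A quick check, again illustrative rather than part of the patch:

    import math

    c1 = math.sqrt(2 / math.pi)
    c2 = 0.044715 * math.sqrt(2 / math.pi)
    print(c1)  # 0.7978845608028654, the literal hardcoded in PATCH 3/4
    print(c2)  # ~0.0356774081363001, the value quoted in its comment
    assert c1 == 0.7978845608028654  # the shortest repr round-trips exactly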