diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7432967c290d..3b94ba1d6672 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1155,11 +1155,12 @@ class FastGelu(OnnxOpConverter):

     @classmethod
     def _impl_v1(cls, bb, inputs, attr, params):
-        if inputs[1]:
+        x = inputs[0]
+        if len(inputs) > 1 and inputs[1] is not None:
             bias = inputs[1]
             bias_shape = bias.struct_info.shape
             assert len(bias_shape) == 1, "bias term must be a 1D tensor"
-            x += bias
+            x = bb.emit(relax.op.add(x, bias))

         # Declare consts
         const_dtype = x.struct_info.dtype
@@ -1169,11 +1170,13 @@ def _impl_v1(cls, bb, inputs, attr, params):
         const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)

         # Compute FastGelu
-        term1 = relax.op.multiply(half, x)
-        term2 = relax.op.multiply(const1, x)
-        term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype)))
-        tanh = relax.op.tanh(relax.op.add(term2, term3))
-        return relax.op.multiply(term1, relax.op.add(one, tanh))
+        term1 = bb.emit(relax.op.multiply(half, x))
+        term2 = bb.emit(relax.op.multiply(const1, x))
+        # use x^3 = x * x * x instead of pow(x, 3) for better performance
+        x_cubed = bb.emit(relax.op.multiply(relax.op.multiply(x, x), x))
+        term3 = bb.emit(relax.op.multiply(const2, x_cubed))
+        tanh = bb.emit(relax.op.tanh(relax.op.add(term2, term3)))
+        return bb.emit(relax.op.multiply(term1, relax.op.add(one, tanh)))


 class BiasGelu(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index e4960e5b1a4d..a8d434e89434 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -828,6 +828,36 @@ def test_bias_gelu():
     verify_binary("BiasGelu", [32, 32], [32], [32, 32], domain="com.microsoft")


+def test_fast_gelu():
+    """Test FastGelu with and without bias"""
+    # Test FastGelu without bias
+    fast_gelu_node = helper.make_node("FastGelu", ["x"], ["y"], domain="com.microsoft")
+    graph = helper.make_graph(
+        [fast_gelu_node],
+        "fast_gelu_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [32, 32])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, producer_name="fast_gelu_test")
+    check_correctness(model)
+
+    # Test FastGelu with bias
+    fast_gelu_with_bias_node = helper.make_node(
+        "FastGelu", ["x", "bias"], ["y"], domain="com.microsoft"
+    )
+    graph_with_bias = helper.make_graph(
+        [fast_gelu_with_bias_node],
+        "fast_gelu_with_bias_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
+    )
+    model_with_bias = helper.make_model(graph_with_bias, producer_name="fast_gelu_with_bias_test")
+    check_correctness(model_with_bias)
+
+
 def test_where():
     where_node = helper.make_node("Where", ["a", "b", "c"], ["d"])
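For reference, the expression the converter emits is the standard tanh approximation of GELU that the com.microsoft FastGelu contrib op specifies: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), applied after the optional 1D bias is added to x. Below is a minimal NumPy sketch of that formula, handy as an independent oracle when reading the tests above. It is illustrative only, not part of the patch; the fast_gelu helper name is made up here.

```python
import numpy as np

def fast_gelu(x: np.ndarray, bias: np.ndarray | None = None) -> np.ndarray:
    """Tanh-approximated GELU as specified by the com.microsoft FastGelu op."""
    if bias is not None:
        x = x + bias  # optional 1D bias is added before the activation
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

# Example mirroring the shapes used in test_fast_gelu above.
y = fast_gelu(np.random.randn(32, 32).astype("float32"),
              bias=np.random.randn(32).astype("float32"))
```

The patch distributes sqrt(2/pi) across the sum: with const2 = 0.044715 * sqrt(2/pi) as defined in the second hunk (and const1, defined just above it, presumably sqrt(2/pi)), tanh(const1 * x + const2 * x^3) is algebraically the same expression as the tanh term above.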