38 changes: 37 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -926,10 +926,45 @@ def _impl_v1(cls, inputs, attr, params):
        return _op.multiply(term1, term2)


+class FastGelu(OnnxOpConverter):
+    """Operator converter for FastGelu from Microsoft onnxruntime contrib opset.
+
+    fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
+                 = 0.5x(1 + tanh(sqrt(2/pi)x + 0.044715sqrt(2/pi)x^3))
+                 = 0.5x(1 + tanh(c1 * x + c2 * x^3))
+    , where
+        c1 = sqrt(2/pi)
+        c2 = 0.044715 * sqrt(2/pi)
+    """

+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        x = inputs[0]
+        if inputs[1]:
+            bias = inputs[1]
+            bias_shape = infer_shape(bias)
+            assert len(bias_shape) == 1, "bias term must be a 1D tensor"
+            x += bias
+
+        # Declare consts
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        const1 = _expr.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+        const2 = _expr.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
Comment on lines +953 to +954

Contributor: I took a look at the ONNX documentation, and they use hard-coded constants in their implementation. Let's check that for fp16 we will have the same accuracy with the current implementation as ONNX does.

Contributor (Author): done

+        # Compute FastGelu
+        term1 = _op.multiply(half, x)
+        term2 = _op.multiply(const1, x)
+        term3 = _op.multiply(const2, _op.power(x, _expr.const(3, const_dtype)))
+        tanh = _op.tanh(_op.add(term2, term3))
+        return _op.multiply(term1, _op.add(one, tanh))


class BiasGelu(OnnxOpConverter):
    """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.

-    bias_gelu(x, b) = 0.5(x, b)(1 + erf((x + b)/sqrt(2)))
+    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
    """

@classmethod
@@ -5335,6 +5370,7 @@ def _get_convert_map(opset):
"Selu": Selu.get_converter(opset),
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"FastGelu": FastGelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
        # TODO: We need a better way to handle different domains, in case
        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
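The docstring formula and the reviewer's fp16 question can be sanity-checked outside of TVM. The following is a minimal, self-contained sketch (not part of this PR; the function names gelu_exact and fast_gelu are illustrative only) that compares the tanh-based FastGelu approximation against the exact erf-based Gelu in both float32 and float16:

import math

import numpy as np


def gelu_exact(x):
    # Exact Gelu: 0.5 * x * (1 + erf(x / sqrt(2)))
    return np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x], dtype=x.dtype)


def fast_gelu(x):
    # Tanh approximation used by the converter above:
    # 0.5 * x * (1 + tanh(c1 * x + c2 * x^3))
    c1 = math.sqrt(2.0 / math.pi)  # ~0.7978845608
    c2 = 0.044715 * math.sqrt(2.0 / math.pi)  # ~0.0356774081
    return (0.5 * x * (1.0 + np.tanh(c1 * x + c2 * x**3))).astype(x.dtype)


for np_dtype in (np.float32, np.float16):
    x = np.linspace(-5.0, 5.0, 101, dtype=np_dtype)
    gap = np.max(np.abs(fast_gelu(x) - gelu_exact(x)))
    # Reports the worst-case gap between the approximation and exact Gelu
    # at each precision; float16 carries roughly 3 decimal digits, which is
    # consistent with the atol=1e-3 used by the float16 tests below.
    print(np_dtype.__name__, gap)
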
62 changes: 41 additions & 21 deletions tests/python/frontend/onnx/test_forward.py
@@ -5638,65 +5638,85 @@ def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
    verify_reverse_sequence(x, sequence_lens, 1, 0)


@pytest.mark.parametrize("op_name", ["Gelu", "FastGelu"], scope="session")
@pytest.mark.parametrize("data_type", ["float16", "float32"], scope="session")
@tvm.testing.parametrize_targets
def test_gelu(target, dev):
def test_gelu(target, dev, data_type, op_name):
"""test_gelu"""
dtype = np.dtype(data_type)
tensor_type = mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
absolute_tolerance = 1e-3 if data_type == "float16" else 1e-5

    def verify_gelu(x):
        node = onnx.helper.make_node(
-            "Gelu",
+            op_name,
            inputs=["x"],
            outputs=["y"],
            domain="com.microsoft",
        )

        graph = helper.make_graph(
            [node],
-            "gelu_test",
-            inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))],
-            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
+            f"{op_name}_test",
+            inputs=[helper.make_tensor_value_info("x", tensor_type, list(x.shape))],
+            outputs=[helper.make_tensor_value_info("y", tensor_type, list(x.shape))],
        )

model = helper.make_model(graph, producer_name="gelu_test")
verify_with_ort_with_inputs(model, [x], [x.shape], target=target, dev=dev)
model = helper.make_model(graph, producer_name=f"{op_name}_test")
verify_with_ort_with_inputs(
model, [x], [x.shape], atol=absolute_tolerance, dtype=data_type, target=target, dev=dev
)

-    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=np.float32)
+    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=dtype)
    verify_gelu(x)
-    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
+    x = np.array([[1, 2], [3, 4]], dtype=dtype)
    verify_gelu(x)


@pytest.mark.parametrize("op_name", ["BiasGelu", "FastGelu"], scope="session")
@pytest.mark.parametrize("data_type", ["float16", "float32"], scope="session")
@tvm.testing.parametrize_targets
def test_biasgelu(target, dev):
def test_biasgelu(target, dev, data_type, op_name):
"""test_biasgelu"""
dtype = np.dtype(data_type)
tensor_type = mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
absolute_tolerance = 1e-3 if data_type == "float16" else 1e-5

    def verify_biasgelu(x, bias):
        node = onnx.helper.make_node(
-            "BiasGelu",
+            op_name,
            inputs=["x", "bias"],
            outputs=["y"],
            domain="com.microsoft",
        )

        graph = helper.make_graph(
            [node],
-            "biasgelu_test",
+            f"{op_name}_test",
            inputs=[
-                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)),
-                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+                helper.make_tensor_value_info("x", tensor_type, list(x.shape)),
+                helper.make_tensor_value_info("bias", tensor_type, list(bias.shape)),
            ],
-            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
+            outputs=[helper.make_tensor_value_info("y", tensor_type, list(x.shape))],
        )

model = helper.make_model(graph, producer_name="biasgelu_test")
verify_with_ort_with_inputs(model, [x, bias], [x.shape], target=target, dev=dev)
model = helper.make_model(graph, producer_name=f"{op_name}_test")
verify_with_ort_with_inputs(
model,
[x, bias],
[x.shape],
atol=absolute_tolerance,
dtype=data_type,
target=target,
dev=dev,
)

-    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=np.float32)
-    bias = np.repeat(2.0, 7).astype("float32")
+    x = np.array([-1.0, 0, 1.0, 100.0, -100.0, 1000.0, -1000.0], dtype=dtype)
+    bias = np.repeat(2.0, 7).astype(dtype)
    verify_biasgelu(x, bias)

-    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
-    bias = np.array([0.3, 4.0], dtype=np.float32)
+    x = np.array([[1, 2], [3, 4]], dtype=dtype)
+    bias = np.array([0.3, 4.0], dtype=dtype)
    verify_biasgelu(x, bias)


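For reference, here is one parametrized case above expanded by hand: a standalone sketch (assuming onnx and onnxruntime are installed; the opset pins and the float32 choice are illustrative) that builds a FastGelu model in the com.microsoft domain and evaluates it with ONNX Runtime, which is the reference side of verify_with_ort_with_inputs:

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

x = np.array([-1.0, 0.0, 1.0, 100.0, -100.0], dtype=np.float32)

node = helper.make_node("FastGelu", inputs=["x"], outputs=["y"], domain="com.microsoft")
graph = helper.make_graph(
    [node],
    "fastgelu_test",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
)
# Contrib ops live in the com.microsoft domain, which must be declared in the
# model's opset imports or ONNX Runtime will reject the node.
model = helper.make_model(
    graph,
    producer_name="fastgelu_test",
    opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
(ort_out,) = sess.run(None, {"x": x})
print(ort_out)  # reference values the TVM converter output is compared against

Roughly speaking, verify_with_ort_with_inputs then imports the same model through TVM's Relay frontend, runs it on the given target, and asserts the two outputs agree within the supplied atol.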