Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit bc91aea

Browse files
author
Wei Chu
committed
remove temperature
1 parent f866fd0 commit bc91aea

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -854,19 +854,14 @@ def convert_softmax(node, **kwargs):
854854

855855
axis = int(attrs.get("axis", -1))
856856
temperature = attrs.get("temperature", None)
857-
if not temperature:
858-
temperature = 1.0
859-
else:
860-
temperature = float(temperature)
857+
if temperature and float(temperature) != 1.0:
858+
raise NotImplementedError("Temperature will be supported in onnx opset13.")
861859
use_length = attrs.get("use_length", None)
862860
input_type = kwargs["in_type"]
863861
data = input_nodes[0]
864862

865863
nodes = [
866-
create_tensor([temperature], name+"_temp", kwargs["initializer"], dtype="float64"),
867-
make_node("Cast", [name+"_temp"], [name+"_T"], to=input_type),
868-
make_node("Div", [data, name+"_T"], [name+"_div_out"]),
869-
make_node("Exp", [name+"_div_out"], [name+"_exp_out"]),
864+
make_node("Exp", [data], [name+"_exp_out"]),
870865
make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
871866
]
872867
if len(input_nodes) == 1:

tests/python-pytest/onnx/test_operators.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,17 +306,16 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
306306

307307

308308
@pytest.mark.parametrize('dtype', ['float16', 'float32'])
309-
@pytest.mark.parametrize('temperature', [0.3, 0.5, 1.0])
310-
def test_onnx_export_softmax(tmp_path, dtype, temperature):
309+
def test_onnx_export_softmax(tmp_path, dtype):
311310
x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
312-
M1 = def_model('softmax', temperature=temperature)
311+
M1 = def_model('softmax')
313312
op_export_test('softmax_1', M1, [x], tmp_path)
314-
M2 = def_model('softmax', use_length=True, axis=0, temperature=temperature)
313+
M2 = def_model('softmax', use_length=True, axis=0)
315314
l2 = mx.nd.array([[2,0,2,1],[1,1,2,1], [0,0,0,1]], dtype=int)
316315
op_export_test('softmax_2', M2, [x, l2], tmp_path)
317-
M3 = def_model('softmax', use_length=True, axis=-1, temperature=temperature)
316+
M3 = def_model('softmax', use_length=True, axis=-1)
318317
l3 = mx.nd.array([[2,0,4],[0,0,0]], dtype=int)
319318
op_export_test('softmax_3', M3, [x, l3], tmp_path)
320-
M4 = def_model('softmax', use_length=True, axis=1, temperature=temperature)
319+
M4 = def_model('softmax', use_length=True, axis=1)
321320
l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int)
322321
op_export_test('softmax_4', M4, [x, l4], tmp_path)

0 commit comments

Comments
 (0)