diff --git a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py index 6c0a1dffc30c..23d4f4b45b11 100644 --- a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py +++ b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py @@ -82,6 +82,8 @@ def callback( self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map ) -> tvm.relay.Expr: params = self.params_class(post.op.body) + quant_min = -128 + quant_max = 127 ifm = post.args[0] ifm_dtype = ifm.checked_type.dtype @@ -121,12 +123,14 @@ def callback( ifm2_scale=0.0, ifm2_zero_point=int(params.ifm.q_params.zero_point), ofm_scale=1.0, - ofm_zero_point=127, + ofm_zero_point=quant_max, ifm_channels=depth, ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", activation="LUT", + clip_min=-255, + clip_max=0, ) # PASS 2 - SHR @@ -147,8 +151,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, rounding_mode="NATURAL", ) @@ -165,6 +169,9 @@ def callback( ofm_channels=1, upscale="NONE", ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 4 - CLZ @@ -177,6 +184,9 @@ def callback( ofm_scale=0.0, ofm_zero_point=int(params.ifm.q_params.zero_point), ofm_channels=1, + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 5 - Sub @@ -196,6 +206,9 @@ def callback( ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 6 - Sub @@ -215,6 +228,9 @@ def callback( ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 7 - SHL @@ -229,13 +245,13 @@ def callback( ifm2_zero_point=0, ofm_scale=0.0, ofm_zero_point=int(params.ifm.q_params.zero_point), - ifm_channels=depth, + ifm_channels=1, ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 8 - Sub @@ -255,6 +271,9 @@ def callback( ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 9 - SHL @@ -274,8 +293,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 10 - Add @@ -296,8 +315,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, use_rescale=True, rescale_scale=1, rescale_shift=1, @@ -316,13 +335,13 @@ def callback( ifm2_zero_point=0, ofm_scale=2.0, ofm_zero_point=0, - ifm_channels=depth, + ifm_channels=1, ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128 * 2, - clip_max=127 * 2, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 12 - Add @@ -343,8 +362,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, ) nr_x = rescale_w_offset @@ -368,8 +387,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128 * 2, - clip_max=127 * 2, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 14, 19, 24 - Sub @@ -388,6 +407,9 @@ def callback( ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 15, 20, 25 - Mul @@ -407,8 +429,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128 * 2, - clip_max=127 * 2, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 16, 21, 26 - Mul @@ -428,8 +450,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 17, 22, 27 - Add @@ -448,6 +470,9 @@ def callback( ifm2_channels=1, reversed_operands=False, ofm_dtype="int32", + activation="CLIP", + clip_min=quant_min, + clip_max=quant_max, ) # PASS 28 - Mul @@ -468,8 +493,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128, - clip_max=127, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 29 - Mul @@ -489,8 +514,8 @@ def callback( reversed_operands=False, ofm_dtype="int32", activation="CLIP", - clip_min=-128 * 2, - clip_max=127 * 2, + clip_min=quant_min, + clip_max=quant_max, ) # PASS 30 - SHR diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 1b643f815721..c952a13c52d6 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -3526,9 +3526,10 @@ def representative_dataset(): assert tuple(func_body.args[1].checked_type.shape) == (256,) -@pytest.mark.parametrize("ifm_shape", [(1, 12), (1, 12, 32)]) -def test_tflite_softmax(ifm_shape): +def test_tflite_softmax(): + np.random.seed(0) dtype = "int8" + ifm_shape = (1, 12) def create_tflite_graph(): @tf.function @@ -3539,7 +3540,7 @@ def softmax(x): # Convert the model def representative_dataset(): for _ in range(100): - data = np.random.rand(*tuple(ifm_shape)) + data = np.random.uniform(low=-1, high=2, size=tuple(ifm_shape)) yield [data.astype(np.float32)] converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) @@ -3554,44 +3555,54 @@ def representative_dataset(): def verify(ext_func): out_op = ext_func.body ops = [] - # List of expected operations and their type if it exists - expected_ops = [ - ("reshape", None), - ("reshape", None), - ("contrib.ethosu.pooling", "MAX"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "SHR"), - ("contrib.ethosu.pooling", "SUM"), - ("contrib.ethosu.unary_elementwise", "CLZ"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "SHL"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "SHL"), - ("contrib.ethosu.binary_elementwise", "ADD"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "ADD"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "ADD"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "ADD"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "ADD"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "MUL"), - ("contrib.ethosu.binary_elementwise", "SUB"), - ("contrib.ethosu.binary_elementwise", "SHR"), - ("reshape", None), + # List of expected operations, their type and activation parameters if it exists + expected_ops_params = [ + ("reshape", None, [None, None, None, None, None, None]), + ("reshape", None, [None, None, None, None, None, None]), + ("contrib.ethosu.pooling", "MAX", [0.011756093241274357, -43, None, None, 0.0, -43]), + ( + "contrib.ethosu.binary_elementwise", + "SUB", + [0.011756093241274357, -43, 0.0, -43, 1.0, 127], + ), + ("contrib.ethosu.binary_elementwise", "SHR", [1.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.pooling", "SUM", [0.0, 0, None, None, 0.0, -43]), + ("contrib.ethosu.unary_elementwise", "CLZ", [0.0, 0, None, None, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "ADD", [0.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "ADD", [2.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 0.0, 0, 1.0, 0]), + ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]), + ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]), + ("contrib.ethosu.binary_elementwise", "SHR", [2.0, 0, 0.0, 0, 0.00390625, -128]), + ("reshape", None, [None, None, None, None, None, None]), ] + def get_attr_value(op, attr_name): + if hasattr(op.attrs, attr_name): + return op.attrs[attr_name] + else: + return None + def get_op_type(op): if hasattr(op.attrs, "pooling_type"): return op.attrs.pooling_type @@ -3599,6 +3610,16 @@ def get_op_type(op): return op.attrs.operator_type return None + def get_activation_params(op): + activation_params = [] + activation_params.append(get_attr_value(op, "ifm_scale")) + activation_params.append(get_attr_value(op, "ifm_zero_point")) + activation_params.append(get_attr_value(op, "ifm2_scale")) + activation_params.append(get_attr_value(op, "ifm2_zero_point")) + activation_params.append(get_attr_value(op, "ofm_scale")) + activation_params.append(get_attr_value(op, "ofm_zero_point")) + return activation_params + def _visit(stmt): if isinstance(stmt, relay.expr.Call): ops.append(stmt) @@ -3616,9 +3637,18 @@ def _visit(stmt): assert ofm.dtype == dtype # check operations - - ops = [(op.op.name, get_op_type(op)) for op in ops] - assert expected_ops == ops + for op, expected_op_params in zip(ops, expected_ops_params): + activation_params = get_activation_params(op) + expected_op_name, expected_op_type, expected_activation_params = expected_op_params + assert op.op.name == expected_op_name + assert expected_op_type == get_op_type(op) + for activation_param, expected_activation_param in zip( + activation_params, expected_activation_params + ): + if isinstance(activation_param, float): + assert math.isclose(expected_activation_param, activation_param, abs_tol=1e-7) + else: + assert expected_activation_param == activation_param softmax_pattern_table = [ (