diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 2c6ca1f74000..3fdcdb6c24b5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1584,6 +1584,92 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class FullyConnectedRewriter(DFPatternCallback):
+    """Legalize Fully Connected (with bias and clip) to an NPU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.FullyConnectedParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.FullyConnectedParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        # IFM reshapes
+        ifm = post.args[0]
+        if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1:
+            ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1]))
+
+        # Weight transformations
+        weights_values = params.weights.values
+        weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2))
+        if params.activation:
+            activation = "CLIP"
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+        bias_values = (
+            params.biases.tensor.data.asnumpy()
+            if params.biases
+            else np.zeros((params.ofm.shape[-1]))
+        )
+        scale_bias = vela_api.pack_biases(
+            biases=bias_values,
+            ifm_scale=params.ifm.q_params.scale_f32,
+            ifm_dtype=np.dtype(params.ifm.dtype),
+            weight_scales=params.weights.q_params.scale_f32,
+            ofm_scale=params.ofm.q_params.scale_f32,
+            is_activation_tanh_or_sigmoid=False,
+        )
+        ethosu_fc = ethosu_ops.ethosu_conv2d(
+            ifm=ifm,
+            weight=relay.const(weights_values_ohwi, params.weights.values.dtype),
+            scale_bias=relay.const(scale_bias, "uint8"),
+            lut=relay.const([], dtype="int8"),
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            weight_zero_point=int(params.weights.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            kernel_shape=[1, 1],
+            ofm_channels=params.weights.shape[0],
+            strides=(1, 1),
+            padding=(0, 0, 0, 0),
+            dilation=(1, 1),
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            upscale="NONE",
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+
+        if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == params.ofm.shape[2] == 1:
+            ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape)
+        return ethosu_fc
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeFullyConnected:
+    """This is the pass that wraps the FullyConnectedRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(FullyConnectedRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1621,6 +1707,7 @@ def transform_module(
         mod = LegalizeSqueeze()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
+        mod = LegalizeFullyConnected()(mod)
         mod = LegalizeNoOps()(mod)
         return mod
 
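The core of FullyConnectedRewriter above is the identity that a fully connected layer over a (1, C_in) input computes the same values as a 1x1 convolution over a (1, 1, 1, C_in) NHWC feature map with OHWI weights, which is why the callback only needs two reshapes and an np.expand_dims. Below is a minimal NumPy sketch of that equivalence; the concrete shapes and the einsum spelling are illustrative, not taken from the patch:

```python
import numpy as np

# (batch, C_in) input and (C_out, C_in) weights, as qnn.dense lays them out
ifm = np.random.randint(-128, 127, size=(1, 14), dtype=np.int8)
weights = np.random.randint(-128, 127, size=(32, 14), dtype=np.int8)

# Fully connected: (1, 14) @ (14, 32) -> (1, 32)
dense_out = ifm.astype(np.int32) @ weights.T.astype(np.int32)

# The same result as a 1x1 convolution, arranged the way the callback does it:
# IFM reshaped to NHWC with unit spatial dims, weights expanded to OHWI.
ifm_nhwc = ifm.reshape(1, 1, 1, 14)
weights_ohwi = np.expand_dims(weights, axis=(1, 2))  # (32, 1, 1, 14)
conv_out = np.einsum(
    "nhwi,oxyi->nhwo", ifm_nhwc.astype(np.int32), weights_ohwi.astype(np.int32)
)

assert np.array_equal(conv_out.reshape(1, 32), dense_out)
```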
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 8234dc047fd8..dffc237e791c 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -129,6 +129,20 @@ class DequantizeArgs(Enum):
     IFM_ZERO_POINT = 2
 
 
+class QDenseArgs(Enum):
+    """
+    This is a helper enum to access the correct index of
+    qnn.dense arguments
+    """
+
+    IFM = 0
+    WEIGHTS = 1
+    IFM_ZERO_POINT = 2
+    WEIGHTS_ZERO_POINT = 3
+    IFM_SCALE = 4
+    WEIGHTS_SCALE = 5
+
+
 def is_composite_func(func: relay.Function, name: str) -> bool:
     """
     This method checks whether the call is to
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index d83c5d38f429..0893be4bb84a 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1537,6 +1537,110 @@ def squeeze_pattern():
     return is_op("squeeze")(wildcard())
 
 
+class FullyConnectedParams:
+    """
+    This class will parse a call to an ethos-u.fully_connected composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.fully_connected"
+
+    @requires_vela
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QDenseArgs  # type: ignore
+        from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
+        from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+
+        self.activation = None
+        if str(func_body.op) == "clip":
+            self.activation = func_body
+            requantize_op = self.activation.args[0]
+        else:
+            requantize_op = func_body
+
+        call = requantize_op.args[0]
+        if str(requantize_op.args[0].op) == "nn.bias_add":
+            bias_add = call
+            qnn_dense = call.args[0]
+        else:
+            bias_add = None
+            qnn_dense = call
+
+        # weights & biases are params as they should be constant
+        self.weights = TensorParams(
+            qnn_dense.args[QDenseArgs.WEIGHTS.value],
+            None,
+            qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
+            qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
+        )
+        self.biases = (
+            TensorParams(
+                bias_add.args[BiasAddArgs.BIASES.value],
+                None,
+                requantize_op.args[RequantArgs.IFM_SCALE.value],
+                requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
+            )
+            if bias_add
+            else None
+        )
+        self.ifm = TensorParams(
+            qnn_dense.args[QDenseArgs.IFM.value],
+            None,
+            qnn_dense.args[QDenseArgs.IFM_SCALE.value],
+            qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            func_body,
+            None,
+            requantize_op.args[RequantArgs.OFM_SCALE.value],
+            requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
+        )
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether Fully Connected has compatible attributes with HW
+        """
+
+        def check_weights_fc(weights):
+            """Checks whether weight tensor is compatible with HW"""
+            weights_limit = 127 * 65536
+            # A saturation upper bound check for accumulators
+            weights.values = weights.values - weights.q_params.zero_point
+            axis = 1
+            sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis))
+            if not sum_weights <= weights_limit:
+                return False
+            return True
+
+        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
+            return False
+        if not check_weights_fc(self.weights):
+            return False
+        if not check_bias(self.biases):
+            return False
+        if not check_batch_size(self.ifm):
+            return False
+        # Check input shape
+        if not len(self.ifm.shape) == 2:
+            return False
+        # Check output shape
+        if not len(self.ofm.shape) == 2:
+            return False
+        return True
+
+
+def qnn_fc_pattern():
+    dense = is_op("qnn.dense")(
+        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    optional_bias_add = is_op("nn.bias_add")(dense, is_constant())
+    req = is_op("qnn.requantize")(
+        dense | optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    optional_clip = req.optional(is_op("clip"))
+    return optional_clip
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1555,6 +1659,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_conv2d_transpose_pattern(),
             lambda pat: QnnConv2DTransposeParams(pat).is_valid(),
         ),
+        (
+            FullyConnectedParams.composite_name,
+            qnn_fc_pattern(),
+            lambda pat: FullyConnectedParams(pat).is_valid(),
+        ),
         (
             MaxPool2DParams.composite_name,
             qnn_maxpool2d_pattern(),
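For orientation, the quantized Relay fragment that qnn_fc_pattern is written to match looks roughly like the sketch below: the nn.bias_add and clip calls correspond to the optional parts of the pattern, and the positional argument order of qnn.dense is exactly what the QDenseArgs enum indexes. This is a hand-written sketch with made-up shapes and quantization parameters, not output captured from the compiler:

```python
import numpy as np
from tvm import relay

ifm = relay.var("ifm", shape=(1, 14), dtype="int8")
dense = relay.qnn.op.dense(
    ifm,                                              # QDenseArgs.IFM
    relay.const(np.zeros((32, 14), dtype="int8")),    # QDenseArgs.WEIGHTS
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.5, "float32"),
    kernel_scale=relay.const(0.5, "float32"),
    units=32,
)
bias_add = relay.nn.bias_add(dense, relay.const(np.zeros((32,), dtype="int32")))
requantize = relay.qnn.op.requantize(
    bias_add,
    input_scale=relay.const(0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)
clip = relay.clip(requantize, a_min=-128.0, a_max=127.0)
```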
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 105f907e2209..ad874588a1ab 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1182,5 +1182,35 @@ def leaky_relu_func(x):
     _compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
+@pytest.mark.parametrize("ofm_channels", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    accel_type,
+    ifm_shape,
+    ofm_channels,
+    use_bias,
+    activation_function,
+):
+    @tf.function
+    def fully_connected(x):
+        bias_shape = ofm_channels
+        bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
+        w = tf.constant(
+            np.random.uniform(size=[ifm_shape[1], ofm_channels]),
+            dtype=tf.float32,
+        )
+        x = tf.matmul(x, w)
+        if use_bias:
+            x = tf.nn.bias_add(x, bias)
+        if activation_function == "RELU":
+            x = tf.nn.relu(x)
+        return x
+
+    _compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
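One detail both test files rely on: after full int8 quantization, the tf.nn.relu in the model above does not survive as a standalone operator. The TFLite converter fuses it into the quantized output range, the TVM frontend re-emits that range as a clip, and the rewriter folds the clip into the CLIP activation attribute that verify() checks in the next file. A rough sketch of the bounds involved, with an arbitrarily chosen zero point:

```python
# For an int8 tensor with zero point zp, float ReLU (max(x, 0.0)) is a clamp
# of the quantized value: real 0.0 quantizes to zp, and 127 is the int8 top.
# These become the clip_min/clip_max that FullyConnectedRewriter reads from
# the matched clip op.
zp = -3                       # illustrative output zero point
clip_min, clip_max = zp, 127  # quantized bounds equivalent to the float ReLU
```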
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 51548bce8b34..710c3e8c8812 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -2421,5 +2421,108 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
+@pytest.mark.parametrize("ofm_channels", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    ifm_shape,
+    ofm_channels,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                bias_shape = ofm_channels
+                bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
+                w = tf.constant(
+                    np.random.uniform(size=[ifm_shape[1], ofm_channels]),
+                    dtype=tf.float32,
+                )
+                x = tf.matmul(x, w)
+                if use_bias:
+                    x = tf.nn.bias_add(x, bias)
+                if activation_function == "RELU":
+                    x = tf.nn.relu(x)
+                return x
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body.args[0]
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == [1, 1] + list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+
+        # check OFM
+        ofm = op.checked_type
+        assert list(ofm.shape) == [1, 1, 1, ofm_channels]
+        assert str(ofm.dtype) == dtype
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert list(weights_ohwi.shape) == [ofm_channels, 1, 1, ifm_shape[1]]
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        assert list(op.attrs.padding) == [0, 0, 0, 0]
+        assert list(op.attrs.strides) == [1, 1]
+        assert list(op.attrs.dilation) == [1, 1]
+        if activation_function == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    fc_pattern_table = [
+        (
+            ethosu.FullyConnectedParams.composite_name,
+            ethosu.qnn_fc_pattern(),
+            lambda pat: ethosu.FullyConnectedParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, fc_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], fc_params)
+    mod = partition_ethosu_by_table(mod, fc_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
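Both test files end with a pytest.main hook, so each can be run directly; to run only the tests added in this patch, a filtered invocation along these lines should work (assuming an Ethos-U-enabled TVM build and the working directory at the TVM source root):

```python
import pytest

# Select just the new fully connected tests (paths as in the diff above).
pytest.main(
    [
        "tests/python/contrib/test_ethosu/test_codegen.py",
        "tests/python/contrib/test_ethosu/test_legalize.py",
        "-k",
        "test_tflite_fully_connected",
    ]
)
```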