diff --git a/apps/uma/_template/patterns.py b/apps/uma/_template/patterns.py index ce25fe4dff8e..1c841f2dbf1d 100644 --- a/apps/uma/_template/patterns.py +++ b/apps/uma/_template/patterns.py @@ -23,3 +23,8 @@ def conv2d_pattern(): pattern = is_op("nn.conv2d")(wildcard(), wildcard()) pattern = pattern.has_attr({"strides": [1, 1], "groups": 1}) return pattern + + +def dense_pattern(): + pattern = is_op("nn.dense")(wildcard(), wildcard()) + return pattern diff --git a/python/tvm/relay/backend/contrib/uma/api/lower.py b/python/tvm/relay/backend/contrib/uma/api/lower.py index 34630949a151..334b6d101f82 100644 --- a/python/tvm/relay/backend/contrib/uma/api/lower.py +++ b/python/tvm/relay/backend/contrib/uma/api/lower.py @@ -60,27 +60,7 @@ def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFu """ def _get_tensors(te_cached_func): - outputs = list(te_cached_func.outputs) - stack = [] - visited = set() - for output_ in outputs: - if output_ not in visited: - visited.add(output_) - stack.append(output_) - - args = [] - while len(stack) != 0: - tensor = stack.pop() - if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp): - args.append(tensor) - elif isinstance(tensor.op, tvm.te.tensor.ComputeOp): - inputs = tensor.op.input_tensors - for input_ in inputs: - if input_ not in visited: - visited.add(input_) - stack.append(input_) - - return args + outputs + return list(te_cached_func.inputs) + list(te_cached_func.outputs) lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE") te_cached_func = lower_to_te(relay_prim_func) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 5d7fb62cd204..563a7dff4a50 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -571,9 +571,7 @@ def _create_header_file(tensor_name, npy_data, output_path, data_linkage): header_file.write("};\n\n") -def convert_to_relay( - tflite_model_buf, -): +def convert_to_relay(tflite_model_buf, bind_params_by_name=True): """Convert a tflite model buffer in a Relay module""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: @@ -588,7 +586,8 @@ def convert_to_relay( raise ImportError("The tflite package must be installed") mod, params = relay.frontend.from_tflite(tflite_model) - mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) + if bind_params_by_name: + mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) return mod, params @@ -931,20 +930,30 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): return dict(zip(output_tensor_names, out)) -def create_relay_module_and_inputs_from_tflite_file(tflite_model_file): +def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_params_by_name=True): """A helper function to create a Relay IRModule with inputs and params from a tflite file""" with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - mod, params = convert_to_relay(tflite_model_buf) + mod, params = convert_to_relay(tflite_model_buf, bind_params_by_name) inputs = dict() for param in mod["main"].params: name = str(param.name_hint) data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) - in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) - data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) + if np.issubdtype(dtype, np.floating): + # Since np.random.uniform only allows the ranges of float32, + # at first float16 is used and scaled afterwards, if necessary. + in_min, in_max = (np.finfo("float16").min, np.finfo("float16").max) + data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype) + scale = np.finfo(dtype).min / np.finfo("float16").min + data *= scale + elif np.issubdtype(dtype, np.integer): + in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) + data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) + else: + raise TypeError(f"Type {dtype} not supported") inputs[name] = data return mod, inputs, params diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index 49b4a196bbd4..0d7ed3ab9587 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -16,6 +16,12 @@ # under the License. import pytest + +pytest.importorskip("tflite") +pytest.importorskip("tensorflow") + +import os +import tensorflow as tf from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER from tvm.relay import transform, testing from tvm.testing.aot import ( @@ -23,6 +29,7 @@ AOTTestRunner, generate_ref_data, compile_and_run, + create_relay_module_and_inputs_from_tflite_file, ) import tvm @@ -132,5 +139,80 @@ def test_mobilenet(): ) +def test_tflite_model(): + """ + End-to-end test of TF-Lite file using UMA + """ + tflite_file = "/tmp/model.tflite" + if os.path.exists(tflite_file): + os.remove(tflite_file) + generate_tflite_file(tflite_file) + + pytest.importorskip("tflite") + + interpreter = tf.lite.Interpreter(model_path=tflite_file) + tf_model_details = interpreter.get_input_details() + mod, _, params = create_relay_module_and_inputs_from_tflite_file( + tflite_file, bind_params_by_name=False + ) + + uma_backend = VanillaAcceleratorBackend() + uma_backend.register() + target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c")) + target_c = tvm.target.Target("c") + + # Generation of test input and output + data_shape = [int(x) for x in mod["main"].params[0].type_annotation.shape] + data = np.random.uniform(size=data_shape).astype("float32") + input_list = {str(tf_model_details[0]["name"]): data} + output_list = generate_ref_data(mod, input_list, params) + + # UMA partitioning (needs to be done after generate_ref_data) + mod = uma_backend.partition(mod) + + aot_test_model = AOTTestModel(module=mod, inputs=input_list, outputs=output_list, params=params) + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": "greedy_by_size"} + ) + + compile_and_run( + aot_test_model, + test_runner, + interface_api="c", + use_unpacked_api=True, + workspace_byte_alignment=1, + debug_calculated_workspaces=False, + target=[target_c, target], + ) + + +def generate_tflite_file(tflite_filename): + mnist = tf.keras.datasets.mnist + (x_train, y_train), (x_test, y_test) = mnist.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + x_train, x_test = x_train.reshape(-1, 28, 28, 1), x_test.reshape(-1, 28, 28, 1) + tf_model = tf.keras.models.Sequential( + [ + tf.keras.Input(shape=(28, 28, 1)), + tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"), + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(32, activation="relu"), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10), + ] + ) + output = tf_model(x_train[:1]) + output = output.numpy() + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss(y_train[:1], output).numpy() + tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss) + tf_model.fit(x_train, y_train, epochs=1) + + tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model) + tflite_model = tflite_converter.convert() + with open(tflite_filename, "wb") as f: + f.write(tflite_model) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py index e7a6b21d4ab5..043203e22a99 100644 --- a/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py +++ b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py @@ -24,7 +24,7 @@ ) from apps.uma._template.codegen import gen_includes -from apps.uma._template.patterns import conv2d_pattern +from apps.uma._template.patterns import conv2d_pattern, dense_pattern from tvm.relay.backend.contrib.uma import uma_available pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") @@ -40,6 +40,7 @@ def __init__(self): # Relay to Relay function registration ####################################################################### self._register_pattern("conv2d", conv2d_pattern()) + self._register_pattern("dense", dense_pattern()) ####################################################################### # Relay to TIR function registration