diff --git a/docs/conf.py b/docs/conf.py
index 08fbedb8ffca..eb2b39d4b1fd 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -550,6 +550,9 @@ def force_gc(gallery_conf, fname):
     gc.collect()
 
 
+# Skip certain files to avoid dependency issues
+filename_pattern_default = "^(?!.*micro_mlperftiny.py).*$"
+
 sphinx_gallery_conf = {
     "backreferences_dir": "gen_modules/backreferences",
     "doc_module": ("tvm", "numpy"),
@@ -562,7 +565,7 @@ def force_gc(gallery_conf, fname):
     "within_subsection_order": WithinSubsectionOrder,
     "gallery_dirs": gallery_dirs,
     "subsection_order": subsection_order,
-    "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", ".py"),
+    "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", filename_pattern_default),
     "download_all_examples": False,
     "min_reported_time": 60,
     "expected_failing_examples": [],
diff --git a/gallery/how_to/work_with_microtvm/micro_mlperftiny.py b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py
new file mode 100644
index 000000000000..79308e072365
--- /dev/null
+++ b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _tutorial-micro-MLPerfTiny:
+
+Creating Your MLPerfTiny Submission with microTVM
+=================================================
+**Author**:
+`Mehrdad Hessar <https://github.com/mehrdadh>`_
+
+This tutorial showcases how to build an MLPerfTiny submission with microTVM. It
+shows the steps to import a TFLite model from the MLPerfTiny benchmark models,
+compile it with TVM, and generate a Zephyr project which can be flashed to a
+Zephyr-supported board to benchmark the model with the EEMBC runner.
+"""
+
+######################################################################
+#
+# .. include:: ../../../../gallery/how_to/work_with_microtvm/install_dependencies.rst
+#
+
+import os
+import pathlib
+import tarfile
+import tempfile
+import shutil
+
+######################################################################
+#
+# .. include:: ../../../../gallery/how_to/work_with_microtvm/install_zephyr.rst
+#
+
+
+######################################################################
+#
+# **Note:** Install CMSIS-NN only if you are interested in generating this
+# submission using the CMSIS-NN code generator.
+#
+
+######################################################################
+#
+# .. include:: ../../../../gallery/how_to/work_with_microtvm/install_cmsis.rst
+#
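+
+######################################################################
+#
+# **Note:** If you installed CMSIS-NN, the project generation step at the
+# end of this tutorial reads the library location from the `CMSIS_PATH`
+# environment variable (falling back to `/content/cmsis`); the path below
+# is only a placeholder:
+#
+# .. code-block:: bash
+#
+#     export CMSIS_PATH=/path/to/cmsis
+#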
+
+######################################################################
+# Import Python dependencies
+# -------------------------------
+#
+import tensorflow as tf
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.contrib.download import download_testdata
+from tvm.micro import export_model_library_format
+from tvm.micro.model_library_format import generate_c_interface_header
+from tvm.micro.testing.utils import (
+    create_header_file,
+    mlf_extract_workspace_size_bytes,
+)
+
+######################################################################
+# Import the Visual Wake Word Model
+# --------------------------------------------------------------------
+#
+# To begin with, download and import the Visual Wake Word (VWW) TFLite model from MLPerfTiny.
+# This model is originally from the `MLPerf Tiny repository <https://github.com/mlcommons/tiny>`_.
+# We also capture metadata from the TFLite model, such as the input/output names
+# and quantization parameters, which will be used in the following steps.
+#
+# We use an index for each model when building the submission. The indices
+# are defined as follows:
+#
+# * Keyword Spotting (KWS): 1
+# * Visual Wake Word (VWW): 2
+# * Anomaly Detection (AD): 3
+# * Image Classification (IC): 4
+#
+# To build another model, update the model URL, the short name, and the index
+# number below. If you would like to build the submission with CMSIS-NN, set
+# the TVM_USE_CMSIS environment variable:
+#
+# .. code-block:: bash
+#
+#     export TVM_USE_CMSIS=1
+#
+
+MODEL_URL = "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite"
+MODEL_PATH = download_testdata(MODEL_URL, "vww_96_int8.tflite", module="model")
+
+MODEL_SHORT_NAME = "VWW"
+MODEL_INDEX = 2
+
+USE_CMSIS = os.environ.get("TVM_USE_CMSIS", False)
+
+with open(MODEL_PATH, "rb") as model_file:
+    tflite_model_buf = model_file.read()
+try:
+    import tflite
+
+    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+except AttributeError:
+    import tflite.Model
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+interpreter = tf.lite.Interpreter(model_path=str(MODEL_PATH))
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+input_name = input_details[0]["name"]
+input_shape = tuple(input_details[0]["shape"])
+input_dtype = np.dtype(input_details[0]["dtype"]).name
+output_name = output_details[0]["name"]
+output_shape = tuple(output_details[0]["shape"])
+output_dtype = np.dtype(output_details[0]["dtype"]).name
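+
+######################################################################
+#
+# As a quick sanity check, you can print the captured metadata; for this
+# quantized VWW model we expect an `int8` input of shape `(1, 96, 96, 3)`:
+#
+# .. code-block:: python
+#
+#     print("Input:", input_name, input_shape, input_dtype)
+#     print("Output:", output_name, output_shape, output_dtype)
+#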
+
+# We extract the quantization information from the TFLite model. This is
+# required for all models except Anomaly Detection, because for the other
+# models we send quantized data from the host to the interpreter, whereas
+# for the AD model we send floating-point data and quantization happens on
+# the microcontroller.
+if MODEL_SHORT_NAME != "AD":
+    quant_output_scale = output_details[0]["quantization_parameters"]["scales"][0]
+    quant_output_zero_point = output_details[0]["quantization_parameters"]["zero_points"][0]
+
+relay_mod, params = relay.frontend.from_tflite(
+    tflite_model, shape_dict={input_name: input_shape}, dtype_dict={input_name: input_dtype}
+)
+
+######################################################################
+# Defining Target, Runtime and Executor
+# --------------------------------------------------------------------
+#
+# Now we need to define the target, runtime and executor to compile this model. In this tutorial,
+# we use Ahead-of-Time (AoT) compilation and build a standalone project. This is different
+# from using AoT in host-driven mode, where the target communicates with the host through the
+# host-driven AoT executor to run inference.
+#
+
+# Use the C runtime (crt)
+RUNTIME = Runtime("crt")
+
+# Use the AoT executor with `unpacked-api=True` and `interface-api=c`. `interface-api=c` forces
+# the compiler to generate C-style function APIs and `unpacked-api=True` forces the compiler
+# to generate minimal unpacked-format inputs, which reduces the stack memory usage when calling
+# the inference layers of the model.
+EXECUTOR = Executor(
+    "aot",
+    {"unpacked-api": True, "interface-api": "c", "workspace-byte-alignment": 8},
+)
+
+# Select a Zephyr board
+BOARD = os.getenv("TVM_MICRO_BOARD", default="nucleo_l4r5zi")
+
+# Get the full target description using the BOARD
+TARGET = tvm.micro.testing.get_target("zephyr", BOARD)
+
+######################################################################
+# Compile the model and export the Model Library Format
+# --------------------------------------------------------------------
+#
+# Now we compile the model for the target and export it in Model Library
+# Format. We also need to calculate the workspace size that is required
+# for the compiled model.
+#
+
+config = {"tir.disable_vectorize": True}
+if USE_CMSIS:
+    from tvm.relay.op.contrib import cmsisnn
+
+    config["relay.ext.cmsisnn.options"] = {"mcpu": TARGET.mcpu}
+    relay_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params, mcpu=TARGET.mcpu)
+
+with tvm.transform.PassContext(opt_level=3, config=config):
+    module = tvm.relay.build(
+        relay_mod, target=TARGET, params=params, runtime=RUNTIME, executor=EXECUTOR
+    )
+
+temp_dir = tvm.contrib.utils.tempdir()
+model_tar_path = temp_dir / "model.tar"
+export_model_library_format(module, model_tar_path)
+workspace_size = mlf_extract_workspace_size_bytes(model_tar_path)
+
+######################################################################
+# Generate input/output header files
+# --------------------------------------------------------------------
+#
+# To create a microTVM standalone project with AoT, we need to generate
+# input and output header files. These header files are used to connect
+# the input and output API from the generated code to the rest of the
+# standalone project. For this specific submission, we only need to generate
+# the output header file, since the input API call is handled differently.
+#
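+
+######################################################################
+#
+# For reference, the same helper can also emit an input header from a numpy
+# array for models that need one (a sketch; `tar_file` stands for an open
+# `tarfile.TarFile`, like the one created below):
+#
+# .. code-block:: python
+#
+#     create_header_file(
+#         "input_data",
+#         np.zeros(shape=input_shape, dtype=input_dtype),
+#         "include",
+#         tar_file,
+#     )
+#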
+# + +extra_tar_dir = tvm.contrib.utils.tempdir() +extra_tar_file = extra_tar_dir / "extra.tar" + +with tarfile.open(extra_tar_file, "w:gz") as tf: + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_files_path = os.path.join(tar_temp_dir, "include") + os.mkdir(model_files_path) + header_path = generate_c_interface_header( + module.libmod_name, [input_name], [output_name], [], {}, [], 0, model_files_path, {}, {} + ) + tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) + + create_header_file( + "output_data", + np.zeros( + shape=output_shape, + dtype=output_dtype, + ), + "include", + tf, + ) + +###################################################################### +# Create the project, build and prepare the project tar file +# -------------------------------------------------------------------- +# +# Now that we have the compiled model as a model library format, +# we can generate the full project using Zephyr template project. First, +# we prepare the project options, then build the project. Finally, we +# cleanup the temporary files and move the submission project to the +# current working directory which could be downloaded and used on +# your development kit. +# + +input_total_size = 1 +for i in range(len(input_shape)): + input_total_size *= input_shape[i] + +template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) +project_options = { + "extra_files_tar": str(extra_tar_file), + "project_type": "mlperftiny", + "board": BOARD, + "compile_definitions": [ + f"-DWORKSPACE_SIZE={workspace_size + 512}", # Memory workspace size, 512 is a temporary offset + # since the memory calculation is not accurate. + f"-DTARGET_MODEL={MODEL_INDEX}", # Sets the model index for project compilation. + f"-DTH_MODEL_VERSION=EE_MODEL_VERSION_{MODEL_SHORT_NAME}01", # Sets model version. This is required by MLPerfTiny API. + f"-DMAX_DB_INPUT_SIZE={input_total_size}", # Max size of the input data array. + ], +} + +if MODEL_SHORT_NAME != "AD": + project_options["compile_definitions"].append(f"-DOUT_QUANT_SCALE={quant_output_scale}") + project_options["compile_definitions"].append(f"-DOUT_QUANT_ZERO={quant_output_zero_point}") + +if USE_CMSIS: + project_options["compile_definitions"].append(f"-DCOMPILE_WITH_CMSISNN=1") + +# Note: You might need to adjust this based on the board that you are using. +project_options["config_main_stack_size"] = 4000 + +if USE_CMSIS: + project_options["cmsis_path"] = os.environ.get("CMSIS_PATH", "/content/cmsis") + +generated_project_dir = temp_dir / "project" + +project = tvm.micro.project.generate_project_from_mlf( + template_project_path, generated_project_dir, model_tar_path, project_options +) +project.build() + +# Cleanup the build directory and extra artifacts +shutil.rmtree(generated_project_dir / "build") +(generated_project_dir / "model.tar").unlink() + +project_tar_path = pathlib.Path(os.getcwd()) / "project.tar" +with tarfile.open(project_tar_path, "w:tar") as tar: + tar.add(generated_project_dir, arcname=os.path.basename("project")) + +print(f"The generated project is located here: {project_tar_path}") + +###################################################################### +# Use this project with your board +# -------------------------------------------------------------------- +# +# Now that we have the generated project, you can use this project locally +# to flash your board and prepare it for EEMBC runner software. +# To do this follow these steps: +# +# .. 
+
+######################################################################
+# Use this project with your board
+# --------------------------------------------------------------------
+#
+# Now that we have the generated project, you can use this project locally
+# to flash your board and prepare it for the EEMBC runner software.
+# To do this, follow these steps:
+#
+# .. code-block:: bash
+#
+#     tar -xf project.tar
+#     cd project
+#     mkdir build
+#     cd build
+#     cmake ..
+#     make -j2
+#     west flash
+#
+# Now you can connect your board to the EEMBC runner using these
+# `instructions <https://github.com/eembc/energyrunner>`_
+# and benchmark this model on your board.
+#
diff --git a/python/tvm/micro/testing/utils.py b/python/tvm/micro/testing/utils.py
index 097fbf283a58..170c57631444 100644
--- a/python/tvm/micro/testing/utils.py
+++ b/python/tvm/micro/testing/utils.py
@@ -17,6 +17,7 @@
 """Defines the test methods used with microTVM."""
 
+import io
 from functools import lru_cache
 import json
 import logging
@@ -24,6 +25,7 @@
 import tarfile
 import time
 from typing import Union
+import numpy as np
 
 import tvm
 from tvm import relay
@@ -102,7 +104,7 @@ def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[Path, str]) -> int:
     workspace_size = 0
     with tarfile.open(mlf_tar_path, "r:*") as tar_file:
-        tar_members = [ti.name for ti in tar_file.getmembers()]
+        tar_members = [tar_info.name for tar_info in tar_file.getmembers()]
         assert "./metadata.json" in tar_members
         with tar_file.extractfile("./metadata.json") as f:
             metadata = json.load(f)
@@ -133,3 +135,43 @@ def get_conv2d_relay_module():
     mod = tvm.IRModule.from_expr(f)
     mod = relay.transform.InferType()(mod)
     return mod
+
+
+def _npy_dtype_to_ctype(data: np.ndarray) -> str:
+    if data.dtype == "int8":
+        return "int8_t"
+    elif data.dtype == "int32":
+        return "int32_t"
+    elif data.dtype == "uint8":
+        return "uint8_t"
+    elif data.dtype == "float32":
+        return "float"
+    else:
+        raise ValueError(f"Data type {data.dtype} not expected.")
+
+
+def create_header_file(tensor_name: str, npy_data: np.ndarray, output_path: str, tar_file: tarfile.TarFile):
+    """
+    This method generates a header file containing the data contained in the numpy array provided
+    and adds the header file to a tar file.
+    It is used to capture the tensor data (for both inputs and outputs).
+    """
+    header_file = io.StringIO()
+    header_file.write("#include <stddef.h>\n")
+    header_file.write("#include <stdint.h>\n")
+    header_file.write("#include <dlpack/dlpack.h>\n")
+    header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n")
+    header_file.write(f"{_npy_dtype_to_ctype(npy_data)} {tensor_name}[] =")
+
+    header_file.write("{")
+    for i in np.ndindex(npy_data.shape):
+        header_file.write(f"{npy_data[i]}, ")
+    header_file.write("};\n\n")
+
+    header_file_bytes = bytes(header_file.getvalue(), "utf-8")
+    raw_path = Path(output_path) / f"{tensor_name}.h"
+    tar_info = tarfile.TarInfo(name=str(raw_path))
+    tar_info.size = len(header_file_bytes)
+    tar_info.mode = 0o644
+    tar_info.type = tarfile.REGTYPE
+    tar_file.addfile(tar_info, io.BytesIO(header_file_bytes))
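+
+
+# Example usage (an illustrative sketch, not part of any test):
+#
+#   with tarfile.open("extra.tar", "w:gz") as tar:
+#       create_header_file("input_data", np.zeros((1, 4), dtype="int8"), "include", tar)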
diff --git a/tests/micro/zephyr/utils.py b/tests/micro/zephyr/utils.py
index 42419b637fa4..bdac4e9c63a7 100644
--- a/tests/micro/zephyr/utils.py
+++ b/tests/micro/zephyr/utils.py
@@ -32,6 +32,7 @@
 import tvm.micro
 from tvm.micro import export_model_library_format
 from tvm.micro.model_library_format import generate_c_interface_header
+from tvm.micro.testing.utils import create_header_file
 from tvm.micro.testing.utils import (
     mlf_extract_workspace_size_bytes,
     aot_transport_init_wait,
@@ -106,42 +107,6 @@ def build_project(
     return project, project_dir
 
 
-def create_header_file(tensor_name, npy_data, output_path, tar_file):
-    """
-    This method generates a header file containing the data contained in the numpy array provided.
-    It is used to capture the tensor data (for both inputs and expected outputs).
-    """
-    header_file = io.StringIO()
-    header_file.write("#include <stddef.h>\n")
-    header_file.write("#include <stdint.h>\n")
-    header_file.write("#include <dlpack/dlpack.h>\n")
-    header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n")
-
-    if npy_data.dtype == "int8":
-        header_file.write(f"int8_t {tensor_name}[] =")
-    elif npy_data.dtype == "int32":
-        header_file.write(f"int32_t {tensor_name}[] = ")
-    elif npy_data.dtype == "uint8":
-        header_file.write(f"uint8_t {tensor_name}[] = ")
-    elif npy_data.dtype == "float32":
-        header_file.write(f"float {tensor_name}[] = ")
-    else:
-        raise ValueError("Data type not expected.")
-
-    header_file.write("{")
-    for i in np.ndindex(npy_data.shape):
-        header_file.write(f"{npy_data[i]}, ")
-    header_file.write("};\n\n")
-
-    header_file_bytes = bytes(header_file.getvalue(), "utf-8")
-    raw_path = pathlib.Path(output_path) / f"{tensor_name}.h"
-    ti = tarfile.TarInfo(name=str(raw_path))
-    ti.size = len(header_file_bytes)
-    ti.mode = 0o644
-    ti.type = tarfile.REGTYPE
-    tar_file.addfile(ti, io.BytesIO(header_file_bytes))
-
-
 # TODO move CMSIS integration to microtvm_api_server.py
 # see https://discuss.tvm.apache.org/t/tvm-capturing-dependent-libraries-of-code-generated-tir-initially-for-use-in-model-library-format/11080
 def loadCMSIS(temp_dir):
diff --git a/tests/scripts/request_hook/request_hook.py b/tests/scripts/request_hook/request_hook.py
index 4e3db220e0b4..b033f1ca8457 100644
--- a/tests/scripts/request_hook/request_hook.py
+++ b/tests/scripts/request_hook/request_hook.py
@@ -208,6 +208,7 @@
     "https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5": f"{BASE}/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5",
     "https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5": f"{BASE}/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5",
     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz": f"{BASE}/tensorflow/tf-keras-datasets/mnist.npz",
+    "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite": f"{BASE}/mlcommons/tiny/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite",
 }
diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh
index 6153cdf82392..0b43c9c1fa8f 100755
--- a/tests/scripts/task_python_microtvm.sh
+++ b/tests/scripts/task_python_microtvm.sh
@@ -51,6 +51,13 @@
 python3 gallery/how_to/work_with_microtvm/micro_aot.py
 python3 gallery/how_to/work_with_microtvm/micro_pytorch.py
 ./gallery/how_to/work_with_microtvm/micro_tvmc.sh
 
+# without CMSIS-NN
+python3 gallery/how_to/work_with_microtvm/micro_mlperftiny.py
+# with CMSIS-NN
+export TVM_USE_CMSIS=1
+python3 gallery/how_to/work_with_microtvm/micro_mlperftiny.py
+export TVM_USE_CMSIS=
+
 # Tutorials running with Zephyr
 export TVM_MICRO_USE_HW=1
 export TVM_MICRO_BOARD=qemu_x86