From b32bebae79108ed868ff7630a7499f2911e0015e Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Thu, 5 Jan 2023 12:05:22 +0000 Subject: [PATCH] Add support for named outputs in MLF archive Following from #12789, this adds support for determining the output tensor name from the input model within the MLF metadata json. Co-authored-by: Ashutosh Parkhi --- python/tvm/micro/model_library_format.py | 119 ++++++++---------- .../test_micro_model_library_format.py | 84 +++++++++++++ 2 files changed, 135 insertions(+), 68 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 5aa2d154ba57..0f30c39ad476 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -26,7 +26,6 @@ import typing import tvm -from tvm.ir.type import TupleType from tvm.micro import get_standalone_crt_dir from .._ffi import get_global_func from ..contrib import utils @@ -217,6 +216,29 @@ def _create_type_metadata(input_type): } +def _flatten_tuple_outputs(ret_type, predefined_names, offset=0): + if isinstance(ret_type, tvm.ir.tensor_type.TensorType): + name = predefined_names[offset] if predefined_names else f"output{offset}" + return {name: ret_type} + + added_fields = len(ret_type.fields) + outputs = {} + for output_index in range(added_fields): + next_output = offset + len(outputs) + outputs.update( + _flatten_tuple_outputs(ret_type.fields[output_index], predefined_names, next_output) + ) + + return outputs + + +def _get_outputs_from_ret_type(ret_type, predefined_names): + if isinstance(ret_type, tvm.ir.tensor_type.TensorType): + name = predefined_names[0] if predefined_names else "output" + return {name: ret_type} + return _flatten_tuple_outputs(ret_type, predefined_names) + + def _build_function_memory_map(function_metadata): """Build a simple map that shows how much workspace is required to execute each primitive function. 
The main_func describes how much memory is required @@ -297,29 +319,25 @@ def _create_empty_entry(target_device_type): target_main_entries[int(target.get_target_device_type())] = _create_empty_entry( int(target.get_target_device_type()) ) - target_main_entries[int(target.get_target_device_type())]["io_size_bytes"] = int( - main_func_metadata.io_sizes[target] - ) + target_main_on_device = target_main_entries[int(target.get_target_device_type())] + target_main_on_device["io_size_bytes"] = int(main_func_metadata.io_sizes[target]) - # Now, we also add the information about the size of each input and output of the main - # function (in bytes) - input_dict = {} - for input_param in main_func_metadata.relay_primfuncs[target].params: - input_dict[input_param.name_hint] = _create_type_metadata(input_param.checked_type) - target_main_entries[int(target.get_target_device_type())]["inputs"] = input_dict - - output_dict = {} - # For output, we dont have the name of the output, so we enumerate them - if isinstance(main_func_metadata.relay_primfuncs[target].ret_type, tvm.ir.type.TupleType): - output_list = _convert_tuple_to_outputs( - main_func_metadata.relay_primfuncs[target].ret_type - ) - for i, output_type in enumerate(output_list): - output_dict[f"output{i}"] = _create_type_metadata(output_type) - else: - output_type = main_func_metadata.relay_primfuncs[target].ret_type - output_dict["output"] = _create_type_metadata(output_type) - target_main_entries[int(target.get_target_device_type())]["outputs"] = output_dict + main_relay_func = main_func_metadata.relay_primfuncs[target] + target_main_on_device["inputs"] = { + input_param.name_hint: _create_type_metadata(input_param.checked_type) + for input_param in main_relay_func.params + } + predefined_names = ( + main_relay_func.attrs["output_tensor_names"] + if "output_tensor_names" in main_relay_func.attrs + else None + ) + target_main_on_device["outputs"] = { + name: _create_type_metadata(output_type) + for name, output_type in 
_get_outputs_from_ret_type( + main_relay_func.ret_type, predefined_names + ).items() + } ret = { "operator_functions": func_entries, @@ -328,30 +346,6 @@ def _create_empty_entry(target_device_type): return ret -def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule): - main_func = mod.function_metadata[MAIN_FUNC_NAME_STR] - target = list(main_func.relay_primfuncs.keys())[0] - return main_func.relay_primfuncs[target] - - -def _convert_tuple_to_outputs(ret_type, offset=0): - outputs = [] - added_fields = len(ret_type.fields) - for output_index in range(added_fields): - next_output = offset + len(outputs) - if isinstance(ret_type.fields[output_index], TupleType): - outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], next_output)) - else: - outputs.append(ret_type.fields[output_index]) - return outputs - - -def _get_inputs_and_outputs_from_module(mod): - inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs] - outputs = list(mod.executor_codegen_metadata.outputs) - return inputs, outputs - - def _get_pools_from_module(mod): return list(dict(mod.executor_codegen_metadata.pool_inputs).values()) @@ -462,33 +456,22 @@ def _export_graph_model_library_format( if not include_path.exists(): include_path.mkdir() - inputs, outputs = _get_inputs_and_outputs_from_module(mod) devices = mod.get_devices() pools = _get_pools_from_module(mod) io_pool_allocations = _get_io_pool_allocation_from_module(mod) - workspace_size = int( - metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][ - "workspace_size_bytes" - ] - ) - inputs_sizes = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][ - "inputs" - ] - # Here, we merge the output sizes with the actual output names - output_sizes = {} - for i, key in enumerate( - metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][ - "outputs" - ].keys() - ): - output_sizes[outputs[i]] = metadata["modules"][mod.libmod_name]["memory"][ - 
"functions" - ]["main"][0]["outputs"][key] + main_func = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0] + workspace_size = int(main_func["workspace_size_bytes"]) + inputs = main_func["inputs"] + outputs = main_func["outputs"] + inputs_sizes = {name: property_map["size"] for name, property_map in inputs.items()} + output_sizes = {name: property_map["size"] for name, property_map in outputs.items()} + input_names = list(inputs.keys()) + output_names = list(outputs.keys()) generate_c_interface_header( mod.libmod_name, - inputs, - outputs, + input_names, + output_names, pools, io_pool_allocations, devices, diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 7ccaf72b1baf..e53e0dc96dac 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -632,5 +632,89 @@ def test_multiple_relay_modules_aot_graph(): assert metadata["version"] == _GENERATED_VERSION +@tvm.testing.requires_micro +def test_output_name_single(): + """Generate a conv2d Relay module for testing.""" + input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64") + output_1 = input_a + tvm.relay.const(1, "int64") + attrs = tvm.ir.make_node("DictAttrs", output_tensor_names=["test_output_a"]) + main_func = tvm.relay.Function([input_a], output_1, attrs=attrs) + mod = tvm.IRModule.from_expr(main_func) + mod = tvm.relay.transform.InferType()(mod) + + executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1") + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + micro.export_model_library_format(factory, mlf_tar_path) + + tf = 
tarfile.open(mlf_tar_path) + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + with open(os.path.join(extract_dir, "metadata.json")) as f: + metadata = json.load(f) + + assert metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == { + "test_output_a": {"size": 480, "dtype": "int64"} + } + + +@tvm.testing.requires_micro +def test_output_names_many(): + """Test many named outputs, including nested tuples, in the MLF metadata.""" + input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64") + input_b = tvm.relay.var("input_b", shape=(3, 4), dtype="int32") + input_c = tvm.relay.var("input_c", shape=(3,), dtype="float32") + + output_1 = input_a + tvm.relay.const(1, "int64") + output_2 = input_b + tvm.relay.const(2) + output_3 = input_b + tvm.relay.const(3) + output_4 = input_c + tvm.relay.const(4.0) + + full_output = tvm.relay.Tuple( + [output_1, tvm.relay.Tuple([tvm.relay.Tuple([output_2, output_3]), output_4])] + ) + attrs = tvm.ir.make_node( + "DictAttrs", + output_tensor_names=["test_output_a", "test_output_b", "test_output_c", "test_output_d"], + ) + main_func = tvm.relay.Function([input_a, input_b, input_c], full_output, attrs=attrs) + mod = tvm.IRModule.from_expr(main_func) + mod = tvm.relay.transform.InferType()(mod) + + executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1") + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + micro.export_model_library_format(factory, mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + with open(os.path.join(extract_dir, "metadata.json")) as f: + metadata = json.load(f) + + assert 
metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == { + "test_output_a": {"size": 480, "dtype": "int64"}, + "test_output_b": {"size": 48, "dtype": "int32"}, + "test_output_c": {"size": 48, "dtype": "int32"}, + "test_output_d": {"size": 12, "dtype": "float32"}, + } + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))