From 454949aab366bddf1ad3824782318e9c966d472c Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Fri, 17 Feb 2023 09:56:38 -0800 Subject: [PATCH] relay to relax translator --- apps/relax_examples/e2e_auto_tir.py | 253 ++++++++++++++++ apps/relax_examples/mlp.py | 57 ++++ apps/relax_examples/nn_module.py | 69 +++++ apps/relax_examples/resnet.py | 53 ++++ python/tvm/relax/testing/__init__.py | 1 + python/tvm/relax/testing/relay_translator.py | 251 ++++++++++++++++ python/tvm/relax/testing/transform.py | 125 ++++++++ src/relay/backend/utils.cc | 7 + tests/python/relax/test_relay_translator.py | 300 +++++++++++++++++++ 9 files changed, 1116 insertions(+) create mode 100644 apps/relax_examples/e2e_auto_tir.py create mode 100644 apps/relax_examples/mlp.py create mode 100644 apps/relax_examples/nn_module.py create mode 100644 apps/relax_examples/resnet.py create mode 100644 python/tvm/relax/testing/relay_translator.py create mode 100644 python/tvm/relax/testing/transform.py create mode 100644 tests/python/relax/test_relay_translator.py diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py new file mode 100644 index 000000000000..92cda16f7927 --- /dev/null +++ b/apps/relax_examples/e2e_auto_tir.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import datetime +import os +import csv +import json +import argparse +import logging +from typing import Dict +import numpy as np # type: ignore + +import tvm +from tvm import relay, relax, runtime, transform +from tvm.ir.module import IRModule +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.relax.testing import relay_translator +from tvm.target.target import Target + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + default=None, + ) + args.add_argument( + "--rpc-port", + type=int, + default=None, + ) + args.add_argument( + "--rpc-key", + type=str, + default=None, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--cache-dir", + type=str, + default=None, + ) + args.add_argument( + "--rpc-timeout-sec", + type=int, + default=180, + ) + args.add_argument("--num-measurement-repeats", type=int, default=5) + args.add_argument("--num-measurements", type=int, default=10) + args.add_argument("--results-file", type=str, required=False, default=None) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key: + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=parsed.rpc_timeout_sec, + ) + parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False) + else: + # check all rpc configs are None + assert ( + (parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None) + ), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use PRC server" + parsed.rpc_config = None + parsed.workers = 1 + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def apply_opt_before_tuning( + relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target +): + with transform.PassContext(opt_level=3): + main_func = relay_mod["main"] + bind_main_func = relay.build_module.bind_params_by_name(main_func, params) + relay_mod = IRModule.from_expr(bind_main_func) + relay_mod = relay.transform.SimplifyInference()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + relay_mod = relay.transform.CanonicalizeOps()(relay_mod) + relay_mod = relay.transform.AlterOpLayout()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + + relax_mod = relay_translator.from_relay(relay_mod["main"], target=target) + relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) + relax_mod = relax.transform.FuseOps()(relax_mod) + relax_mod = relax.transform.FuseTIR()(relax_mod) + return relax_mod + + +def f_measurement( + rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray] +): + vm = 
relax.vm.VirtualMachine(exec=rt_mod, device=device) + vm.save_function("main", "measure_func", **input_data, include_return=False) + evaluator = vm.time_evaluator( + func_name="measure_func", + dev=device, + repeat=ARGS.num_measurement_repeats, + number=ARGS.num_measurements, + min_repeat_ms=500, + ) + return evaluator() + + +def get_runner(): + runner_config = { + "evaluator_config": ms.runner.EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ), + "alloc_repeat": ARGS.alloc_repeat, + } + if ARGS.rpc_config: + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config + ) + else: + runner = ms.runner.LocalRunner(**runner_config) + + return runner + + +def main(): + relay_mod, params, (input_name, input_shape, input_dtype) = get_network( + ARGS.workload, + ARGS.input_shape, + cache_dir=ARGS.cache_dir, + ) + input_info = {input_name: input_shape} + input_data = {} + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") + + # translate the ResNet model from Relay to Relax + relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target) + assert isinstance(relax_mod, tvm.IRModule) + + db = ms.relax_integration.tune_relax( + mod=relax_mod, + target=ARGS.target, + params=params, + num_trials_per_iter=64, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, + runner=get_runner(), + work_dir=ARGS.work_dir, + ) + executable = ms.relax_integration.compile_relax( + db, + mod=relax_mod, + target=ARGS.target, + params=params, + ) + + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) + + # for documentation purposes + start_time = datetime.datetime.now() + + if ARGS.rpc_config: + result = run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=executable.mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_measurement, + ) + else: + dev = tvm.device(ARGS.target.kind.name) + result = f_measurement(executable.mod, dev, input_data) + + print(result) + + if not ARGS.results_file: + return + + out_path = os.path.abspath(os.path.expanduser(ARGS.results_file)) + with open(out_path, "w") as out_file: + writer = csv.writer(out_file) + # write experiment parameters at the top as a record + writer.writerow(["start", str(start_time)]) + writer.writerow(["workload", ARGS.workload]) + writer.writerow(["input_shape", ARGS.input_shape]) + writer.writerow(["target", ARGS.target]) + writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats]) + for res in result.results: + writer.writerow([str(res)]) + + +if __name__ == "__main__": + main() diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py new file mode 100644 index 000000000000..02e17dc3041a --- /dev/null +++ b/apps/relax_examples/mlp.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running an MLP model in relax + + +import tvm +from tvm import relax, tir, topi +import numpy as np + + +def build_mlp(data, weight): + bb = relax.BlockBuilder() + + with bb.function("mlp", [data, weight]): + gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False) + gv1 = bb.emit_te(topi.nn.relu, gv0) + bb.emit_func_output(gv1) + + mod = bb.get() + return mod + + +if __name__ == "__main__": + # symbolic dimensions + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + # create data and weight variables + data = relax.Var("data", relax.TensorStructInfo([n, m], "float32")) + weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32")) + + # construct a mlp model + mod = build_mlp(data, weight) + + # build and create vm executor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # run the mlp model on relax vm + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = vm["mlp"](data, weight) + print(res) diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py new file mode 100644 index 000000000000..b57cb00685ae --- /dev/null +++ b/apps/relax_examples/nn_module.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Example code on creating, compiling, and running a neural network with pytorch-like API + + +import tvm +from tvm.relay import Call +from tvm import relax, tir +from tvm.relax.testing import nn +from tvm.script import relax as R +import numpy as np + + +if __name__ == "__main__": + builder = relax.BlockBuilder() + + # a symbolic variable to represent minibatch size + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + + # build a three linear-layer neural network for a classification task + with builder.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + # get and print the IRmodule being built + mod = builder.get() + mod.show() + + # build the IRModule and create relax vm + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init parameters + params = nn.init_params(mod) + + # run the model on relax vm + # the input data has a minibatch size of 3 + data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32)) + res = vm["main"](data, *params) + print(res) diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py new file mode 100644 index 000000000000..df0cab02f19c --- /dev/null +++ b/apps/relax_examples/resnet.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example ResNet workload by translating the Relay program to Relax""" + +import tvm +import tvm.testing +from tvm.relay import testing +from tvm import relax, relay +from tvm.relax.testing import relay_translator, nn +from tvm.runtime import vm as vm_rt +from tvm.script import relax as R +import numpy as np + +if __name__ == "__main__": + relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + + # translate the ResNet model from Relay to Relax + target = tvm.target.Target("llvm", host="llvm") + relax_mod = relay_translator.from_relay(relay_mod["main"], target) + + # print the ResNet IRmodule got translated + relax_mod.show() + + # build the IRModule and create relax vm + ex = relax.vm.build(relax_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init weights and run the model on relax vm + shape = (1, 3, 224, 224) + data = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + params = nn.init_params(relax_mod) + res = vm["main"](data, *params) + + # check correctness by comparing with relay result + exe = relay.vm.compile(relay_mod, target) + relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu()) + inputs = [data] + params + expected_output = relay_vm.run(*inputs) + tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4) diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index ab1dd6f5155e..7344798f70dc 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -18,3 +18,4 @@ """The Relax testing namespace containing nn and translator.""" from .nn import * +from .relay_translator import * diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py new file mode 100644 index 000000000000..fd5aab89fa76 --- /dev/null +++ b/python/tvm/relax/testing/relay_translator.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-argument, invalid-name, no-else-return, too-many-nested-blocks +"""Relay to Relax translator.""" + +from typing import Any, Dict, List, Optional + +import tvm +from tvm import relax, relay +from tvm.ir.module import IRModule +from tvm.relax.testing import nn +from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.meta_schedule.relay_integration import _autotvm_silencer + + +def from_relay( + func: relay.Function, + target: Target, + relay_params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None, +) -> IRModule: + """Convert a Relay function into a Relax program. + + Parameters + ---------- + func : relay.Function + Relay function to be converted. + + target: Target + The target to compile the model, used for selecting topi functions. + + relay_params: Optional[Dict[str, NDArray]] + Parameters to bind. + + opt_level: int + The optimization level. + + pass_config: Optional[Dict[str, Any]] + Pass configuration. + + disabled_pass: Optional[List[str]] + Passes to disable. + + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] + Dict that maps op names to user-defined PrimFuncs. + Takes relay operator names and forces them to user-defined PrimFuncs during translation. + + Returns + ------- + mod : tvm.IRModule + The Relax IRModule for compilation + """ + # A map to store the mapping of Relay Expr to its corresponding Relax var + var_map = {} + # The output of the function + output_var = None + + if not isinstance(target, Target): + target = Target(target) + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = { + "relay.FuseOps.max_depth": 1, # Disable relay fusion + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": True, + } + + if relay_params: + func = relay.build_module.bind_params_by_name(func, relay_params) + + params = [] + tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr] = dict() + + def convert_shape(shape: List[tvm.tir.PrimExpr]) -> List[tvm.tir.PrimExpr]: + """Convert the relay shape to relax shape by changing Any dim to symbolic dim""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + ret.append(tvm.tir.IntImm("int64", int(dim))) + elif isinstance(dim, tvm.tir.Any): + ret.append(tvm.tir.Var("d", "int64")) + else: + ret.append(dim) + return ret + + def _copy_undefined_var_in_shape(sinfo: relax.TensorStructInfo): + def _visit_expr(e: tvm.tir.PrimExpr): + if isinstance(e, tvm.tir.Var) and e not in tir_var_map: + new_var = tvm.tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + assert isinstance( + sinfo.shape, relax.ShapeExpr + ), "arg with TensorStructInfo in Relay translator must have ShapeExpr shape" + for shape_value in sinfo.shape.values: + tvm.tir.stmt_functor.post_order_visit(shape_value, _visit_expr) + + def visit_func(node): + nonlocal output_var + if isinstance(node, relay.Var): + if isinstance(node.type_annotation, relay.TensorType): + var_map[node] = nn.Placeholder( + tuple(convert_shape(node.type_annotation.shape)), + node.type_annotation.dtype, + node.name_hint, + ) + params.append(var_map[node]) + else: + raise TypeError("The type of relay.Var to be translated must be of TensorType.") + elif isinstance(node, relay.Call): + args = node.args + new_args = [] + te_inputs = [] + for arg in 
args: + if arg in var_map: + arg_expr = var_map[arg] + if isinstance(arg_expr.struct_info, relax.TensorStructInfo): + _copy_undefined_var_in_shape(arg_expr.struct_info) + new_args.append(arg_expr) + te_inputs.append(tvm.relax.expr.te_tensor(arg_expr, tir_var_map)) + elif isinstance(arg_expr.struct_info, relax.TupleStructInfo): + n_tensor = len(arg_expr.struct_info.fields) + bound_tuple = bb.lookup_binding(arg_expr) + if isinstance(bound_tuple, relax.Tuple): + assert len(bound_tuple) == n_tensor + for i in range(n_tensor): + if isinstance(bound_tuple, relax.Tuple): + item = bb.emit(bound_tuple[i]) + else: + item = bb.emit(relax.TupleGetItem(arg_expr, i)) + + assert isinstance(item.struct_info, relax.TensorStructInfo), ( + "Relay translator doesn't support Call " + "argument being nested Tensor tuple." + ) + _copy_undefined_var_in_shape(item.struct_info) + new_args.append(item) + te_inputs.append(tvm.relax.expr.te_tensor(item, tir_var_map)) + else: + raise TypeError( + f"CallTIR argument type being {type(arg_expr.checked_type)} is not " + "supported." + ) + + op_name = node.op.name + attrs = node.attrs + out_type = node.checked_type + + if translate_op_with_tir and op_name in translate_op_with_tir: + tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name) + call = relax.call_tir( + tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype) + ) + var = bb.emit(call) + else: + with target: + best_impl, outputs = select_implementation( + node.op, + attrs, + te_inputs, + out_type, + target, + use_autotvm=False, + ) + compute_func = best_impl.compute + name_hint = op_name.split(".")[-1] + var = bb.emit_te( + compute_func, + attrs, + new_args, + node.checked_type, + primfunc_name_hint=name_hint, + ) + + output_var = var + var_map[node] = var + elif isinstance(node, relay.Constant): + # fill the shape and checked_type fields of the Constant + new_constant = relax.Constant(node.data) + var_map[node] = new_constant + elif isinstance(node, relay.Tuple): + new_fields = [] + for field in node.fields: + if field in var_map: + new_fields.append(var_map[field]) + else: + raise RuntimeError("field is not in var_map.") + new_tuple = relax.Tuple(new_fields) + new_tuple_var = relax.BlockBuilder.current().emit(new_tuple) + var_map[node] = new_tuple_var + output_var = new_tuple_var + elif isinstance(node, relay.TupleGetItem): + if node.tuple_value in var_map: + new_tuple = var_map[node.tuple_value] + new_tuple_get_item_node = relax.TupleGetItem(new_tuple, node.index) + new_tuple_get_item_var = relax.BlockBuilder.current().emit(new_tuple_get_item_node) + var_map[node] = new_tuple_get_item_var + output_var = new_tuple_get_item_var + else: + raise RuntimeError("tuple is not in var_map") + elif isinstance(node, relay.Function): + cur_bb = relax.BlockBuilder.current() + gv = cur_bb.emit_output(output_var) + df_block = cur_bb._end_block() + cur_bb._blocks.append(df_block) + cur_bb.emit_func_output(gv, params) + elif isinstance(node, tvm.ir.Op): + pass + else: + raise TypeError("{} is not supported yet.".format(str(type(node)))) + + # List of subset of relay->relay optimizations + # See src/relay/backend/utils.cc::GetPassPrefix() for full list + seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True) + + # Since optimization passes and OpStrategy are highly context-dependent, + # we match the exact same context with `extract_task_from_relay()` env + with _autotvm_silencer(), tvm.transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + 
): + mod = tvm.IRModule.from_expr(func) + mod = seq(mod) + bb = relax.BlockBuilder() + with bb.function("main"): + bb._begin_dataflow_block() + relay.analysis.post_order_visit(mod["main"], visit_func) + + return bb.get() diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py new file mode 100644 index 000000000000..c8ca618d4c1a --- /dev/null +++ b/python/tvm/relax/testing/transform.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ +"""Relax transformation passes for testing""" + +from tvm import ir +from tvm import relax +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.target import Target +from tvm.ir import transform +from tvm.relax import PyExprMutator +from tvm.relax.expr import Call +from tvm.relay.backend.te_compiler import select_implementation + + +@ir.transform.module_pass(opt_level=0) +class LowerWithRelayOpStrategyPass(transform.Pass): + """Lower Relax Op into TIR by using Relay OpStrategy. + + Since operators like conv2d, add, matmul are relay-, relax- independent, + this pass assumes we can always find relay op equivalent for such relax ops, + and use Relay Op Strategy (legacy) to perform lowering and find the TOPI implementation. + + Parameters + ---------- + target : Target + target info + + Returns + ------- + pass : transform.Pass + lowering pass + """ + + def __init__(self, target: Target): + self.target = target + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + """Implement lowering mechanism. + + Parameters + ---------- + mod : IRModule + Input IRModule with Relax ops + + ctx: PassContext + Pass context + + Returns + ------- + out_mod : IRModule + Output IRModule with lowered TIR functions + """ + target = self.target + + @relax.expr_functor.mutator + class Lowerer(PyExprMutator): + """Mutator that performs lowering.""" + + def visit_call_(self, call_node: Call): + # Ignore function calls + # We only target calls for operators + if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)): + return call_node + + # Current relax op name simply adds "relax." prefix to relay op name. + # Thus, remove "relax." prefix to deduce relay op name. + relay_op_name = call_node.op.name[6:] + # Check if equivalent relay op exists. If not, return the original call. + if relay_op_name in ir.Op.list_op_names(): + relay_op = ir.Op.get(relay_op_name) + + # Todo(relax-team): to be revisited - support dyn shape or deprecate. 
+                    tir_var_map = dict()
+                    te_inputs = [relax.expr.te_tensor(arg, tir_var_map) for arg in call_node.args]
+                    best_impl_tuple = select_implementation(
+                        relay_op,
+                        call_node.attrs,
+                        te_inputs,
+                        call_node.checked_type,
+                        target,
+                        use_autotvm=False,
+                    )
+                    compute_func = best_impl_tuple[0].compute
+                    # Extract the name of the operator without the prefix
+                    # e.g., for relay op "nn.conv2d", name_hint would be conv2d
+                    name_hint = relay_op_name.split(".")[-1]
+
+                    return self.builder_.call_te(
+                        compute_func,
+                        call_node.attrs,
+                        call_node.args,
+                        call_node.attrs,
+                        primfunc_name_hint=name_hint,
+                    )
+                else:
+                    return call_node
+
+            # TODO(@team): the transform() wrapper is necessary to include TIR functions.
+            # IMO, this is a bit unintuitive. Can we improve this?
+            def transform(self):
+                for gv, func in mod.functions.items():
+                    if isinstance(func, relax.Function):
+                        updated_func = self.visit_expr(func)
+                        self.builder_.update_func(gv, updated_func)
+                new_mod = self.builder_.get()
+                new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod
+                return new_mod
+
+        return Lowerer().transform()
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 4ff8a59b349e..3fb1c89c286e 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -443,6 +443,13 @@ TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern")
       return DefaultTIRConverterImpl(args, constants, true);
     });
 
+TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq")
+    .set_body_typed([](bool is_homogeneous, bool is_vm) {
+      auto pass_seqs = GetPassPrefix(is_homogeneous, is_vm);
+      transform::Sequential seq(pass_seqs);
+      return seq;
+    });
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py
new file mode 100644
index 000000000000..5f7e05b02d3a
--- /dev/null
+++ b/tests/python/relax/test_relay_translator.py
@@ -0,0 +1,300 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +import tempfile + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm import relax, relay, tir, topi +from tvm.ir.base import assert_structural_equal +from tvm.relax.testing import relay_translator +from tvm.relay import testing +from tvm.runtime import vm +from tvm.script import tir as T +from tvm.target import Target + + +def get_resnet(batch_size, dtype, layout, image_shape): + relay_mod, params = testing.resnet.get_workload( + num_layers=18, + batch_size=batch_size, + dtype=dtype, + layout=layout, + image_shape=image_shape, + ) + + return relay_mod, params + + +def relay_build_and_run(mod, target, dev, params, data): + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + task_scheduler="round-robin", + work_dir=work_dir, + ) + ex = ms.relay_integration.compile_relay( + db, + mod=mod, + target=target, + params=params, + ) + rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev)) + rt_mod.set_input("data", data) + rt_mod.run() + out = rt_mod.get_output(0).numpy() + return ex, rt_mod, out + + +def relax_build_and_run(mod, target, dev, params, data): + mod = relax.transform.BindParams("main", params)(mod) + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relax_integration.tune_relax( + mod=mod, + target=target, + task_scheduler="round-robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + work_dir=work_dir, + ) + ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=params, + ) + vm = relax.VirtualMachine(ex, dev) + res = vm["main"](data) + out = res.numpy() + return ex, vm, out + + +def verify_e2e_translation(target_str, layout, batch_size, image_shape): + target = Target(target_str) + dev = tvm.device(str(target), dev_id=0) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + input_shape = (1, *image_shape) + data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev) + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + assert relax_mod["main"].attrs["global_symbol"] == "main" + + _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) + _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) + tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5) + + +@pytest.mark.skip(reason="take too much time") +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_cpu(layout, batch_size, image_shape): + verify_e2e_translation("llvm --num-cores=16", layout, batch_size, image_shape) + + +@pytest.mark.skip(reason="take too much time") +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): + verify_e2e_translation("cuda", layout, batch_size, image_shape) + + +def verify_extracted_tasks(target_str, layout, batch_size, image_shape): + target = Target(target_str) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + relax_mod = relay_translator.from_relay( + relay_mod["main"], + target, + params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # 
Disable relay fusion + }, + ) + relay_tasks = ms.relay_integration.extract_tasks( + relay_mod, + target=target, + params=params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relax_tasks = ms.relax_integration.extract_tasks( + relax_mod, + target=target, + params=params, + ) + # TODO (yongwww, yuchen): tophub guides relay passes, which causes inconsistent tasks + # assert len(relay_tasks) == len(relax_tasks) + # TODO: Can we compare extracted tasks as well? + + +@pytest.mark.parametrize( + "layout, batch_size, image_shape", + [ + ("NCHW", 1, (3, 224, 224)), + ("NHWC", 1, (224, 224, 3)), + ], +) +def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape): + verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, image_shape) + + +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape): + verify_extracted_tasks("cuda", layout, batch_size, image_shape) + + +def translate_and_build_vms(relay_mod, target_str="llvm", translate_op_with_tir=None): + target = tvm.target.Target(target_str) + + # build the relay IRModule and create relay vm + relay_ex = relay.vm.compile(relay_mod, target) + relay_vm = vm.VirtualMachine(relay_ex, tvm.cpu()) + + # build the relax IRModule and create relax vm + relax_mod = relay_translator.from_relay( + relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir + ) + relax_ex = relax.vm.build(relax_mod, target) + relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu()) + + return relay_vm, relax_vm, relax_mod + + +def verify_vm_outputs( + input_shape, + relay_vm, + relax_vm, + extra_args=[], +): + input = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32)) + + # check correctness by comparing relax and relay result + args = [input] + extra_args + relax_output = relax_vm["main"](*args) + relay_output = relay_vm.run(*args) + tvm.testing.assert_allclose(relay_output.numpy(), relax_output.numpy()) + + +def test_single_dynamic_dim(): + wx, wy = 64, 128 + # create relay module: y = data * weights + bias with dynamic batch dimension + data = relay.var("data", shape=(relay.Any(), wx)) + weights = relay.var("weights", shape=(wx, wy)) + bias = relay.var("bias", shape=(wy,)) + y = relay.nn.matmul(data, weights) + relay_mod = tvm.IRModule.from_expr(relay.Function([data, weights, bias], y + bias)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + weights = tvm.nd.array(np.random.rand(wx, wy).astype(np.float32)) + bias = tvm.nd.array(np.random.rand(wy).astype(np.float32)) + # verify for different batch sizes + verify_vm_outputs([10, wx], relay_vm, relax_vm, [weights, bias]) + verify_vm_outputs([32, wx], relay_vm, relax_vm, [weights, bias]) + + +def test_multiple_dynamic_dims(): + # create relay module: y = a + a, where a has shape = (?, 5, ?) 
+ shape = (relay.Any(), 5, relay.Any()) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a + a)) + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + # verify for different shapes + verify_vm_outputs([2, 5, 10], relay_vm, relax_vm) + verify_vm_outputs([12, 5, 24], relay_vm, relax_vm) + + +def test_layout_transform(): + shape = (1, 3, 224, 224) + a = relay.var("a", shape=shape) + b = relay.layout_transform(a, "NCHW", "NHWC") + relay_mod = tvm.IRModule.from_expr(relay.Function([a], b)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + verify_vm_outputs([1, 3, 224, 224], relay_vm, relax_vm) + + +def test_translate_op_with_tir(): + @T.prim_func + def tir_matmul( + A: T.Buffer((512, 512), "float32"), + B: T.Buffer((512, 512), "float32"), + C: T.Buffer((512, 512), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + shape = (512, 512) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a)) + _, _, relax_mod = translate_and_build_vms( + relay_mod, translate_op_with_tir={"multiply": tir_matmul} + ) + assert_structural_equal(relax_mod["multiply"], tir_matmul) + + +def test_translate_tuple_arg(): + x = relay.var("x", shape=(10, 16)) + y = relay.var("y", shape=(10, 16)) + relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1))) + relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm") + + # Construct the expected module + bb = relax.BlockBuilder() + x_relax = relax.Var("x", relax.TensorStructInfo([10, 16], "float32")) + y_relax = relax.Var("y", relax.TensorStructInfo([10, 16], "float32")) + with bb.function("main", [x_relax, y_relax]): + with bb.dataflow(): + _ = bb.emit(relax.Tuple((x_relax, y_relax))) + lv1 = bb.emit(x_relax) + lv2 = bb.emit(y_relax) + lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + assert_structural_equal(relax_mod, bb.get()) + + +if __name__ == "__main__": + pytest.main([__file__])
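
For reference, below is a minimal usage sketch of the new relay_translator.from_relay API, distilled from apps/relax_examples/resnet.py and the tests in this patch. The toy add workload, shapes, and variable names are illustrative only and are not part of the patch.

# Minimal sketch: translate a small Relay function to Relax and run it on the
# Relax VM. See apps/relax_examples/resnet.py for a full ResNet example.
import numpy as np
import tvm
from tvm import relax, relay
from tvm.relax.testing import relay_translator

# a tiny Relay function: y = x + x
x = relay.var("x", shape=(4, 4), dtype="float32")
relay_mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

# translate Relay -> Relax; the target is used to select TOPI implementations
target = tvm.target.Target("llvm", host="llvm")
relax_mod = relay_translator.from_relay(relay_mod["main"], target)

# compile the translated module and execute it on the Relax VM
ex = relax.vm.build(relax_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
data = tvm.nd.array(np.random.rand(4, 4).astype("float32"))
print(vm["main"](data))

Individual operators can also be forced to a hand-written PrimFunc via the translate_op_with_tir argument, as exercised in test_translate_op_with_tir above.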