diff --git a/3rdparty/cutlass b/3rdparty/cutlass index d8359c804b7e..92ebbf1dc461 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit d8359c804b7e3915a0f0668c19213f63ae88aac6 +Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63 diff --git a/CMakeLists.txt b/CMakeLists.txt index 818e8b50addb..22e82e2fb74a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,6 +289,14 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/relax/ir/*.cc + src/relax/op/*.cc + src/relax/analysis/*.cc + src/relax/transform/*.cc + src/relax/backend/vm/*.cc + src/relax/backend/task_extraction.cc + src/relax/backend/pattern_registry.cc + src/relax/utils.cc ) tvm_file_glob(GLOB CODEGEN_SRCS @@ -335,6 +343,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/*.cc src/runtime/vm/*.cc src/runtime/minrpc/*.cc + src/runtime/relax_vm/*.cc ) if(BUILD_FOR_HEXAGON) diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py new file mode 100644 index 000000000000..8113f942d166 --- /dev/null +++ b/apps/relax_examples/e2e_auto_tir.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
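+# Illustrative invocation of this script (a sketch added in this write-up; the
+# workload name, target string, and paths are placeholder assumptions, not
+# values recorded in this patch):
+#
+#   python3 e2e_auto_tir.py \
+#       --workload resnet_50 \
+#       --input-shape "[1, 3, 224, 224]" \
+#       --target "llvm -num-cores 16" \
+#       --num-trials 2000 \
+#       --work-dir /tmp/relax_tuning \
+#       --results-file /tmp/relax_tuning/results.csv
+#
+# The RPC flags are optional: pass --rpc-host, --rpc-port, and --rpc-key
+# together to measure on a tracker-managed device, or omit all three to
+# build and measure locally.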
+import datetime +import os +import csv +import json +import argparse +import logging +from typing import Dict +import numpy as np # type: ignore + +import tvm +from tvm import relay, relax, runtime, transform +from tvm.ir.module import IRModule +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.relax.testing import relay_translator +from tvm.target.target import Target + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + default=None, + ) + args.add_argument( + "--rpc-port", + type=int, + default=None, + ) + args.add_argument( + "--rpc-key", + type=str, + default=None, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--cache-dir", + type=str, + default=None, + ) + args.add_argument( + "--rpc-timeout-sec", + type=int, + default=180, + ) + args.add_argument("--num-measurement-repeats", type=int, default=5) + args.add_argument("--num-measurements", type=int, default=10) + args.add_argument("--results-file", type=str, required=False, default=None) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key: + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=parsed.rpc_timeout_sec, + ) + parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False) + else: + # check all rpc configs are None + assert ( + (parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None) + ), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use PRC server" + parsed.rpc_config = None + parsed.workers = 1 + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def apply_opt_before_tuning( + relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target +): + with transform.PassContext(opt_level=3): + main_func = relay_mod["main"] + bind_main_func = relay.build_module.bind_params_by_name(main_func, params) + relay_mod = IRModule.from_expr(bind_main_func) + relay_mod = relay.transform.SimplifyInference()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + relay_mod = relay.transform.CanonicalizeOps()(relay_mod) + relay_mod = relay.transform.AlterOpLayout()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + + relax_mod = relay_translator.from_relay(relay_mod["main"], target=target) + relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) + relax_mod = relax.transform.FuseOps()(relax_mod) + relax_mod = relax.transform.FuseTIR()(relax_mod) + return relax_mod + + +def f_measurement( + rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray] +): + vm = 
relax.VirtualMachine(rt_mod, device=device) + vm.save_function("main", "measure_func", **input_data, include_return=False) + evaluator = vm.time_evaluator( + func_name="measure_func", + dev=device, + repeat=ARGS.num_measurement_repeats, + number=ARGS.num_measurements, + min_repeat_ms=500, + ) + return evaluator() + + +def get_runner(): + runner_config = { + "evaluator_config": ms.runner.EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ), + "alloc_repeat": ARGS.alloc_repeat, + } + if ARGS.rpc_config: + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config + ) + else: + runner = ms.runner.LocalRunner(**runner_config) + + return runner + + +def main(): + relay_mod, params, (input_name, input_shape, input_dtype) = get_network( + ARGS.workload, + ARGS.input_shape, + cache_dir=ARGS.cache_dir, + ) + input_info = {input_name: input_shape} + input_data = {} + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") + + # translate the ResNet model from Relay to Relax + relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target) + assert isinstance(relax_mod, tvm.IRModule) + + db = ms.relax_integration.tune_relax( + mod=relax_mod, + target=ARGS.target, + params=params, + num_trials_per_iter=64, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, + runner=get_runner(), + work_dir=ARGS.work_dir, + ) + executable = ms.relax_integration.compile_relax( + db, + mod=relax_mod, + target=ARGS.target, + params=params, + ) + + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) + + # for documentation purposes + start_time = datetime.datetime.now() + + if ARGS.rpc_config: + result = run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=executable.mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_measurement, + ) + else: + dev = tvm.device(ARGS.target.kind.name) + result = f_measurement(executable.mod, dev, input_data) + + print(result) + + if not ARGS.results_file: + return + + out_path = os.path.abspath(os.path.expanduser(ARGS.results_file)) + with open(out_path, "w") as out_file: + writer = csv.writer(out_file) + # write experiment parameters at the top as a record + writer.writerow(["start", str(start_time)]) + writer.writerow(["workload", ARGS.workload]) + writer.writerow(["input_shape", ARGS.input_shape]) + writer.writerow(["target", ARGS.target]) + writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats]) + for res in result.results: + writer.writerow([str(res)]) + + +if __name__ == "__main__": + main() diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py new file mode 100644 index 000000000000..2a81b61543fd --- /dev/null +++ b/apps/relax_examples/mlp.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running an MLP model in relax + + +import tvm +from tvm import relax, tir, topi +import numpy as np + + +def build_mlp(data, weight): + bb = relax.BlockBuilder() + + with bb.function("mlp", [data, weight]): + gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False) + gv1 = bb.emit_te(topi.nn.relu, gv0) + bb.emit_func_output(gv1) + + mod = bb.get() + return mod + + +if __name__ == "__main__": + # symbolic dimensions + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + # create data and weight variables + data = relax.Var("data", relax.TensorStructInfo([n, m], "float32")) + weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32")) + + # construct a mlp model + mod = build_mlp(data, weight) + + # build and create vm executor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # run the mlp model on relax vm + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = vm["mlp"](data, weight) + print(res) diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py new file mode 100644 index 000000000000..57a13e4fb51b --- /dev/null +++ b/apps/relax_examples/nn_module.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
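+# Reading note (added in this write-up, not part of the original patch): the
+# layers used below come from the pytorch-like helpers in tvm.relax.testing.nn,
+# nn.init_params builds concrete NDArray values for each model parameter
+# declared in the module, and the final vm["main"](data, *params) call returns
+# the LogSoftmax output for the minibatch (shape (3, 10) for the batch size of
+# 3 and output_size of 10 used in this script).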
+ +# Example code on creating, compiling, and running a neural network with pytorch-like API + + +import tvm +from tvm.relay import Call +from tvm import relax, tir +from tvm.relax.testing import nn +from tvm.script import relax as R +import numpy as np + + +if __name__ == "__main__": + builder = relax.BlockBuilder() + + # a symbolic variable to represent minibatch size + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + + # build a three linear-layer neural network for a classification task + with builder.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + # get and print the IRmodule being built + mod = builder.get() + mod.show() + + # build the IRModule and create relax vm + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init parameters + params = nn.init_params(mod) + + # run the model on relax vm + # the input data has a minibatch size of 3 + data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32)) + res = vm["main"](data, *params) + print(res) diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py new file mode 100644 index 000000000000..6c7350d77847 --- /dev/null +++ b/apps/relax_examples/resnet.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example ResNet workload by translating the Relay program to Relax""" + +import tvm +import tvm.testing +from tvm.relay import testing +from tvm import relax, relay +from tvm.relax.testing import relay_translator, nn +from tvm.runtime import vm as vm_rt +from tvm.script import relax as R +import numpy as np + +if __name__ == "__main__": + relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + + # translate the ResNet model from Relay to Relax + target = tvm.target.Target("llvm", host="llvm") + relax_mod = relay_translator.from_relay(relay_mod["main"], target) + + # print the ResNet IRmodule got translated + relax_mod.show() + + # build the IRModule and create relax vm + ex = relax.build(relax_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init weights and run the model on relax vm + shape = (1, 3, 224, 224) + data = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + params = nn.init_params(relax_mod) + res = vm["main"](data, *params) + + # check correctness by comparing with relay result + exe = relay.vm.compile(relay_mod, target) + relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu()) + inputs = [data] + params + expected_output = relay_vm.run(*inputs) + tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4) diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index 4c830dce2c30..ffcfa9b842d7 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy index d8a4d4671e86..c1a62736702b 100644 --- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy +++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index cdd2564e0591..e689cbb65583 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! 
// Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index 32dec7863bcf..74e3ddfabeac 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 390c8ddc3dc2..f14e8f541b41 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index 58fe4d14c969..7d5bd3309ee5 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index b5bf5cb1fe40..98e09c393a69 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy index ed5aa8d67954..1a3120efb0e1 100644 --- a/ci/jenkins/generated/lint_jenkinsfile.groovy +++ b/ci/jenkins/generated/lint_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! 
// Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy index 4c748e3f20d7..08143791c68e 100644 --- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy index 72864ec4ca0f..ff10d01670ce 100644 --- a/ci/jenkins/generated/minimal_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy index 2dfeb3561281..df1160b3c1e5 100644 --- a/ci/jenkins/generated/riscv_jenkinsfile.groovy +++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index 27e8f6570ed0..37b50f97ad17 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy new file mode 100644 index 000000000000..0f85a7dc320e --- /dev/null +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -0,0 +1,337 @@ +#!groovy +// -*- mode: groovy -*- + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Jenkins pipeline +// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/ + +// ============================= IMPORTANT NOTE ============================= +// To keep things simple +// This file is manually updated to maintain unity branch specific builds. +// Please do not send this file to main + + +import org.jenkinsci.plugins.pipeline.modeldefinition.Utils + +// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> +ci_lint = 'tlcpack/ci-lint:20221025-182121-e41d0ed6e' +ci_gpu = 'tlcpack/ci-gpu:20221128-070141-ae4fd7df7' +ci_cpu = 'tlcpack/ci-cpu:20230110-070003-d00168ffb' +ci_wasm = 'tlcpack/ci-wasm:v0.72' +ci_i386 = 'tlcpack/ci-i386:v0.75' +ci_qemu = 'tlcpack/ci-qemu:v0.11' +ci_arm = 'tlcpack/ci-arm:v0.08' +ci_hexagon = 'tlcpack/ci-hexagon:20221025-182121-e41d0ed6e' +// <--- End of regex-scanned config. + +// Parameters to allow overriding (in Jenkins UI), the images +// to be used by a given build. When provided, they take precedence +// over default values above. +properties([ + parameters([ + string(name: 'ci_lint_param', defaultValue: ''), + string(name: 'ci_cpu_param', defaultValue: ''), + string(name: 'ci_gpu_param', defaultValue: ''), + string(name: 'ci_wasm_param', defaultValue: ''), + string(name: 'ci_i386_param', defaultValue: ''), + string(name: 'ci_qemu_param', defaultValue: ''), + string(name: 'ci_arm_param', defaultValue: ''), + string(name: 'ci_hexagon_param', defaultValue: '') + ]) +]) + +// tvm libraries +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' +tvm_lib = 'build/libtvm.so, ' + tvm_runtime +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +// command to start a docker container +docker_run = 'docker/bash.sh' +// timeout in minutes +max_time = 240 + +def per_exec_ws(folder) { + return "workspace/exec_${env.EXECUTOR_NUMBER}/" + folder +} + +// initialize source codes +def init_git() { + checkout scm + // Add more info about job node + sh ( + script: './tests/scripts/task_show_node_info.sh', + label: 'Show executor node info', + ) + retry(5) { + timeout(time: 2, unit: 'MINUTES') { + sh (script: 'git submodule update --init -f', label: 'Update git submodules') + } + } +} + +def should_skip_slow_tests(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'GITHUB_TOKEN', + )]) { + // Exit code of 1 means run slow tests, exit code of 0 means skip slow tests + result = sh ( + returnStatus: true, + script: "./tests/scripts/should_run_slow_tests.py --pr '${pr_number}'", + label: 'Check if CI should run slow tests', + ) + } + return result == 0 +} + +def cancel_previous_build() { + // cancel previous build if it is not on main. 
+ if (env.BRANCH_NAME != 'main') { + def buildNumber = env.BUILD_NUMBER as int + // Milestone API allows us to cancel previous build + // with the same milestone number + if (buildNumber > 1) milestone(buildNumber - 1) + milestone(buildNumber) + } +} + +def should_skip_ci(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'TOKEN', + )]) { + // Exit code of 1 means run full CI (or the script had an error, so run + // full CI just in case). Exit code of 0 means skip CI. + git_skip_ci_code = sh ( + returnStatus: true, + script: "./tests/scripts/git_skip_ci.py --pr '${pr_number}'", + label: 'Check if CI should be skipped', + ) + } + return git_skip_ci_code == 0 +} + +cancel_previous_build() + +def lint() { +stage('Prepare') { + node('CPU-SMALL') { + // When something is provided in ci_*_param, use it, otherwise default with ci_* + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + ci_hexagon = params.ci_hexagon_param ?: ci_hexagon + + sh (script: """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + echo " ci_hexagon = ${ci_hexagon}" + """, label: 'Docker image names') + } +} + +stage('Sanity Check') { + timeout(time: max_time, unit: 'MINUTES') { + node('CPU-SMALL') { + ws(per_exec_ws('tvm/sanity')) { + init_git() + is_docs_only_build = sh ( + returnStatus: true, + script: './tests/scripts/git_change_docs.sh', + label: 'Check for docs only changes', + ) + skip_ci = should_skip_ci(env.CHANGE_ID) + skip_slow_tests = should_skip_slow_tests(env.CHANGE_ID) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", + label: 'Run lint', + ) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/unity/task_extra_lint.sh", + label: 'Run extra lint', + ) + } + } + } +} +} + +lint() + +// Run make. First try to do an incremental make from a previous workspace in hope to +// accelerate the compilation. If something is wrong, clean the workspace and then +// build from scratch. +def make(docker_type, path, make_flag) { + timeout(time: max_time, unit: 'MINUTES') { + try { + cmake_build(docker_type, path, make_flag) + // always run cpp test when build + // sh "${docker_run} ${docker_type} ./tests/scripts/task_cpp_unittest.sh" + } catch (hudson.AbortException ae) { + // script exited due to user abort, directly throw instead of retry + if (ae.getMessage().contains('script returned exit code 143')) { + throw ae + } + echo 'Incremental compilation failed. Fall back to build from scratch' + sh ( + script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", + label: 'Clear old cmake workspace', + ) + cmake_build(docker_type, path, make_flag) + cpp_unittest(docker_type) + } + } +} + +// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions. +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so. +tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler. 
+// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz' + +// pack libraries for later use +def pack_lib(name, libs) { + sh (script: """ + echo "Packing ${libs} into ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Stash libraries and show md5') + stash includes: libs, name: name +} + +// unpack libraries saved before +def unpack_lib(name, libs) { + unstash name + sh (script: """ + echo "Unpacked ${libs} from ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Unstash libraries and show md5') +} + +// compress microtvm template projects and pack the tar. +def pack_microtvm_template_projects(name) { + sh( + script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', + label: 'Compress microtvm_template_projects' + ) + pack_lib(name + '-microtvm-libs', microtvm_tar_gz) +} + +def unpack_microtvm_template_projects(name) { + unpack_lib(name + '-microtvm-libs', microtvm_tar_gz) + sh( + script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', + label: 'Unpack microtvm_template_projects' + ) +} + +def ci_setup(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh", + label: 'Set up CI environment', + ) +} + +def python_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_unittest.sh", + label: 'Run Python unit tests', + ) +} + +def fsim_test(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_vta_fsim.sh", + label: 'Run VTA tests in FSIM', + ) +} + +def cmake_build(image, path, make_flag) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod", + label: 'Run cmake build', + ) +} + +def cpp_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_cpp_unittest.sh", + label: 'Build and run C++ tests', + ) +} + +def add_hexagon_permissions() { + sh( + script: 'find build/hexagon_api_output -type f | xargs chmod +x', + label: 'Add execute permissions for hexagon files', + ) +} + +// NOTE: limit tests to relax folder for now to allow us to skip some of the tests +// that are mostly related to changes in main. +// This helps to speedup CI time and reduce CI cost. 
+stage('Build and Test') { + if (is_docs_only_build != 1) { + parallel 'BUILD: GPU': { + node('GPU') { + ws(per_exec_ws('tvm/build-gpu')) { + init_git() + sh "${docker_run} ${ci_gpu} nvidia-smi" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" + make("${ci_gpu}", 'build', '-j2') + sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_relax_gpuonly.sh" + } + } + }, + 'BUILD: CPU': { + node('CPU-SMALL') { + ws(per_exec_ws('tvm/build-cpu')) { + init_git() + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build" + make(ci_cpu, 'build', '-j2') + sh "${docker_run} ${ci_cpu} ./tests/scripts/unity/task_python_relax.sh" + } + } + } + } else { + Utils.markStageSkippedForConditional('BUILD: CPU') + } +} diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index bbbf6b89ba2e..96d5922e84d9 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -92,6 +92,10 @@ if(USE_CUDA) tvm_file_glob(GLOB RUNTIME_CUDA_GRAPH_SRCS src/runtime/graph_executor/cuda_graph/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_GRAPH_SRCS}) endif() + + # Add CUDA builtins to RelaxVM + tvm_file_glob(GLOB RELAX_VM_CUDA_BUILTIN_SRC_CC src/runtime/relax_vm/cuda/*.cc) + list(APPEND RUNTIME_SRCS ${RELAX_VM_CUDA_BUILTIN_SRC_CC}) else(USE_CUDA) list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc) endif(USE_CUDA) diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index afd5ef530252..4b4ef355b678 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -16,8 +16,8 @@ # under the License. if(USE_CUDA AND USE_CUTLASS) - tvm_file_glob(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) - list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) message(STATUS "Build with CUTLASS") endif() diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 2c6b03daeccf..631b3e47aba1 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -21,8 +21,8 @@ if(IS_DIRECTORY ${USE_DNNL}) message(WARNING "Cannot find DNNL library at ${USE_DNNL}.") else() add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -34,8 +34,8 @@ if(IS_DIRECTORY ${USE_DNNL}) endif() elseif((USE_DNNL STREQUAL "ON") OR (USE_DNNL STREQUAL "JSON")) add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) find_library(EXTERN_LIBRARY_DNNL dnnl) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake index 696108b50142..a749b6e80fd2 100644 --- a/cmake/modules/contrib/TensorRT.cmake +++ 
b/cmake/modules/contrib/TensorRT.cmake @@ -23,7 +23,7 @@ include (FindPackageHandleStandardArgs) if(USE_TENSORRT_CODEGEN) message(STATUS "Build with TensorRT codegen") - tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc) + tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc src/relax/backend/contrib/tensorrt/*.cc) set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") tvm_file_glob(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/tensorrt_runtime.cc) set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..c662067a0486 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -367,6 +367,14 @@ class RelayExprNode : public BaseExprNode { * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); + + /*! + * \brief Stores the result of structure information of the + * expression that encapsulate both static shape and + * runtime information such as shape. + */ + mutable Optional struct_info_ = Optional(); + /*! * \return The checked_type */ @@ -454,6 +462,7 @@ class GlobalVarNode : public RelayExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 1493544e7324..381ea6b8d6d3 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -65,6 +65,68 @@ enum class CallingConv : int { kDeviceKernelLaunch = 2, }; +/*! + * \brief Supported linkage types. + */ +enum class LinkageType : int { + /*! + * \brief Internal linkage. + */ + kInternal = 0, + /*! + * \brief External linkage. + - Function with external linkage should have a global symbol attached to it. + */ + kExternal = 1 +}; + +/*! + * \brief Generic attribute names that can be attached to any function. + * + * \sa tvm::tir::attr, tvm::relay::attr + */ +namespace attr { +/*! + * \brief Indicates the special calling convention. + * + * Type: Integer + * + * \sa tvm::CallingConv + */ +constexpr const char* kCallingConv = "calling_conv"; + +/*! + * \brief Compilation target of the function. + * + * Type: Target + * + * \sa tvm::Target + */ +constexpr const char* kTarget = "target"; + +/*! + * \brief Global linker symbol of the function in generated code. + * + * This option forces the code generator to name the + * function with the given. + * + * For example, we could set a global_symbol of a function + * early to make sure that we can always refer to it by + * the symbol name in the generated DLL. + * + * We should not set the attribute for local functions, + * so that the compiler can freely rename them. + * + * A unique global symbol will be automatically assigned + * to each function in the module before the target code + * generation phase. + * + * Type: String + */ +constexpr const char* kGlobalSymbol = "global_symbol"; + +} // namespace attr + /*! * \brief Base node of all functions. * @@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } + /*! + * \brief Get the type of the linkage. + * + * Currently, we only consider external/internal linkage. 
+ * This can be extended in the future when necessary. + * + * \return Linkage type. + * + * \code + * + * void Example(const BaseFunc& f) { + * if (f->GetLinkageType() == tvm::LinkageType::kExternal) { + * // Do not remove a function with external linkage + * } + * } + * + * \endcode + */ + + LinkageType GetLinkageType() const { + if (GetAttr(attr::kGlobalSymbol)) + return LinkageType::kExternal; + else + return LinkageType::kInternal; + } static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; @@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Generic attribute names that can be attached to any function. - * - * \sa tvm::tir::attr, tvm::relay::attr - */ -namespace attr { -/*! - * \brief Indicates the special calling convention. - * - * Type: Integer - * - * \sa tvm::CallingConv - */ -constexpr const char* kCallingConv = "calling_conv"; - -/*! - * \brief Compilation target of the function. - * - * Type: Target - * - * \sa tvm::Target - */ -constexpr const char* kTarget = "target"; - -/*! - * \brief Global linker symbol of the function in generated code. - * - * This option forces the code generator to name the - * function with the given. - * - * For example, we could set a global_symbol of a function - * early to make sure that we can always refer to it by - * the symbol name in the generated DLL. - * - * We should not set the attribute for local functions, - * so that the compiler can freely rename them. - * - * A unique global symbol will be automatically assigned - * to each function in the module before the target code - * generation phase. - * - * Type: String - */ -constexpr const char* kGlobalSymbol = "global_symbol"; - -} // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h new file mode 100644 index 000000000000..65b5e0a3d28d --- /dev/null +++ b/include/tvm/ir/global_info.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/global_info.h + * \brief GlobalInfo are globally static object that are referred by the IR itself. + */ + +#ifndef TVM_IR_GLOBAL_INFO_H_ +#define TVM_IR_GLOBAL_INFO_H_ + +#include "tvm/ir/expr.h" + +namespace tvm { + +/*! + * \brief GlobalInfo are globally static object that are referred by the IR itself. 
+ * Base node for all global info that can appear in the IR + */ +class GlobalInfoNode : public Object { + public: + static constexpr const char* _type_key = "GlobalInfoNode"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); +}; + +/*! + * \brief Managed reference to GlobalInfoNode. + * \sa GlobalInfoNode + */ +class GlobalInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode); +}; + +/*! + * \brief A dummy global info sub-class for testing purpose. + */ +class DummyGlobalInfoNode : public GlobalInfoNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "DummyGlobalInfo"; + + TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const { + return true; + } + + TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {} + TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode); +}; + +/*! + * \brief Managed reference to DummyGlobalInfoNode. + * \sa DummyGlobalInfoNode + */ +class DummyGlobalInfo : public GlobalInfo { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode); +}; + +} // namespace tvm + +#endif // TVM_IR_GLOBAL_INFO_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index fdb44b11887c..4c2d5cd81264 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -63,6 +64,8 @@ class IRModuleNode : public Object { SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; + /*! \brief Globally static object that are referred by the IR itself */ + Map> global_infos; /*! * \brief A map from string names to global variables that * ensures global uniqueness. @@ -115,6 +118,12 @@ class IRModuleNode : public Object { return GetAttr(attr_key, Optional(default_value)); } + /*! + * \brief Get the metadata attributes. + * \returns The additional meta-data attributes + */ + DictAttrs GetAttrs() const { return attrs; } + /*! * \brief Check whether the module has an non-zero integer attr. * @@ -145,6 +154,7 @@ class IRModuleNode : public Object { v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("source_map", &source_map); v->Visit("attrs", &attrs); + v->Visit("global_infos", &global_infos); } TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; @@ -204,6 +214,13 @@ class IRModuleNode : public Object { */ TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type); + /*! + * \brief Update an array of global infos in the global environment. + * \param name The name of the global info. + * \param info The new array of global infos. + */ + TVM_DLL void UpdateGlobalInfo(const String& name, const Array& info); + /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. @@ -353,12 +370,13 @@ class IRModule : public ObjectRef { * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module. * \param map The module source map. - * \param attrs The module attributes. + * \param attrs The module meta-data attributes. + * \param global_infos Global infos in the module. 
*/ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, std::unordered_set import_set = {}, SourceMap map = {}, - DictAttrs attrs = {}); + DictAttrs attrs = {}, Map> global_infos = {}); /*! \brief default constructor */ IRModule() : IRModule(Map({})) {} diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 473e6291685d..ff54a6b5eacd 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -32,18 +32,18 @@ * - Reducing the effort required to implement new passes for compiler * developers, etc. * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work * different granularity, i.e. module level, function level, and even sequential * passe that contains a host of passes. * * However, we also extend the functionality of the traditional pass manager * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass * manager performs the IRModule -> IRModule transformation. All * different types of passes, including the sequential-level pass object, are * essentially pass objects. This design, therefore, effectively provides users * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with + * means to ease the development and testing of Relay/Relax passes. For example, with * the pass manager, external users will be able to have custom passes correctly * scheduled without having to modify a single handcrafted pass order. * @@ -90,7 +90,16 @@ class PassContextNode : public Object { /*! \brief A list of pass instrument implementations. */ Array instruments; - + // TODO(@sunggg): Fix dependency issue in the header file and correct the types + // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h + /*! \brief Trace stack for relax pass infra. */ + mutable Array trace_stack; + /*! \brief List of passes to be traced. If not defined, make every pass traceable. */ + Optional> make_traceable; + /*! \brief Number of evaluations conducted in the pass pipeline. */ + mutable int num_evals{0}; + /*! \brief Database for tuning API. */ + Optional tuning_api_database; PassContextNode() = default; /*! @@ -130,7 +139,27 @@ class PassContextNode : public Object { v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); + v->Visit("trace_stack", &trace_stack); + v->Visit("make_traceable", &make_traceable); + v->Visit("num_evals", &num_evals); + v->Visit("tuning_api_daatabase", &tuning_api_database); + } + + Array GetTraceStack() { return trace_stack; } + void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); } + void PopTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + trace_stack.pop_back(); } + int GetTraceStackSize() { return trace_stack.size(); } + ObjectRef GetCurrentTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. 
Please double check."; + return trace_stack.back(); + } + void SetNumEvals(int _num_evals) { num_evals = _num_evals; } + void IncNumEvals(int _num_evals) { num_evals += _num_evals; } + + Optional GetTuningAPIDatabase() { return tuning_api_database; } static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; @@ -287,6 +316,9 @@ class PassInfoNode : public Object { /*! \brief The name of an optimization/analysis pass. */ String name; + /*! \brief Boolean that tells whether this pass will be traced or not. */ + bool traceable; + /*! \brief The passes that are required to perform the current pass. */ Array required; @@ -296,6 +328,7 @@ class PassInfoNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); + v->Visit("traceable", &traceable); } static constexpr const char* _type_key = "transform.PassInfo"; @@ -314,8 +347,9 @@ class PassInfo : public ObjectRef { * \param opt_level The optimization level * \param name Name of the pass. * \param required The passes that are required to perform the current pass. + * \param traceable Boolean that tells whether the pass is traceable. */ - TVM_DLL PassInfo(int opt_level, String name, Array required); + TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -323,7 +357,7 @@ class PassInfo : public ObjectRef { /*! * \brief PassNode is the base type of differnt types of optimization passes. * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. + * at different granularity of Relay/Relax nodes. */ class PassNode : public Object { public: @@ -396,7 +430,7 @@ class Pass : public ObjectRef { }; /*! - * \brief The SequentialNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay/Relax * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly @@ -489,9 +523,9 @@ class Sequential : public Pass { * * \return The created module pass. */ -TVM_DLL Pass -CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, Array required); +TVM_DLL Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, int opt_level, + String name, Array required, bool traceable = false); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index c6baf5e08be3..ec13635a2643 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -131,8 +131,9 @@ class PrimType : public Type { /*! * \brief Constructor * \param dtype The corresponding dtype. + * \param span The span */ - TVM_DLL explicit PrimType(runtime::DataType dtype); + TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index f4fec04035fc..00e4d925a055 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -43,6 +43,13 @@ class PrinterConfigNode : public Object { std::string ir_prefix = "I"; /*! \brief The prefix of TIR nodes */ std::string tir_prefix = "T"; + /*! \brief The prefix of Relax nodes */ + std::string relax_prefix = "R"; + /*! 
+ * \brief The alias of the current module at cross-function call + * \note Directly use module name if it's empty. + */ + std::string module_alias = "cls"; /*! \brief Default data type of TIR buffer */ DataType buffer_dtype = DataType::Float(32); /*! \brief Default data type of integer literals */ @@ -76,6 +83,9 @@ class PrinterConfigNode : public Object { v->Visit("binding_names", &binding_names); v->Visit("show_meta", &show_meta); v->Visit("ir_prefix", &ir_prefix); + v->Visit("tir_prefix", &tir_prefix); + v->Visit("relax_prefix", &relax_prefix); + v->Visit("module_alias", &module_alias); v->Visit("buffer_dtype", &buffer_dtype); v->Visit("int_dtype", &int_dtype); v->Visit("float_dtype", &float_dtype); @@ -90,6 +100,8 @@ class PrinterConfigNode : public Object { v->Visit("obj_to_annotate", &obj_to_annotate); } + Array GetBuiltinKeywords(); + static constexpr const char* _type_key = "node.PrinterConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); }; diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h new file mode 100644 index 000000000000..21dc5ef42056 --- /dev/null +++ b/include/tvm/relax/analysis.h @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis.h + * \brief The set of Relax specific analysis on IR. + */ +#ifndef TVM_RELAX_ANALYSIS_H_ +#define TVM_RELAX_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { +//----------------------------------- +// Shape expression analysis +//---------------------------------- +/*! + * \brief Can prove the two symbolic shape arrays equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * \return The prove result. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana); + +/*! + * \brief Can prove the two symbolic shape expressions equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana); + +//----------------------------------- +// Foundational StructInfo analysis +//----------------------------------- +/*! 
+ * \brief Get the corresponding static type from a given struct info. + * \param info The struct info. + * \return the corresponding static type. + */ +TVM_DLL Type GetStaticType(const StructInfo& info); + +/*! + * \brief Get the corresponding struct info from static type. + * \param type The input type + * \return the corresponding struct info. + */ +TVM_DLL StructInfo StructInfoFromType(const Type& type); + +/*! + * \return Derive the call's ret value struct info from inputs. + * \param finfo The function struct info. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The derived struct info of the call. + * \note call->op field is ignored during derivation and we only rely on information + * presented by func_sinfo. + */ +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); + +/*! + * \brief Erase the info to a corresponding more coarse grained + * struct info that is still well-defined(with all the vars in scope). + * + * When we are returning a StructInfo to another scope, + * it is important to remember that StructInfo may carry + * dependencies on var that is not defined the other scope. + * + * In such cases, it is important to call EraseToWellDefined to get + * another StructInfo that **only** contains the vars that are defined + * in the target scope. + * + * For example, consider the following function + * + * \code + * + * @R.function + * def f(x: R.Tensor[(n, m)]): + * k = tir.Var("k", "int64") + * v0 = opaque_fn(x) + * v1 = match_cast(v0, R.Tensor[(n, k)]) + * v2 : R.Tensor[(n + 1, k + 2)] = pad(v1) + * return v2 + * + * \endcode + * + * In the above code, the return value y have shape `(n + 1, k + 2)`, + * However, at the level of function signature, only n, m are defined, + * k is undefined here. + * + * When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}), + * we will obtain R.Tensor(ndim=2), which is an erased info that does not depend + * on k(which is undefined from parameter signature). + * + * However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}), + * Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined. + * + * We can also make these var map to return a different expression. + * For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m}) + * will give us R.Tensor[(3, m)], where n get replaced by 2. + * + * Use this function in the following scenarios: + * - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr + * - Decide the deduced return struct_info of a function that can be fully decided by params. + * + * \param info The struct info. + * \param f_shape_var_map callback function to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param f_var_map callback function to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. 
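 *
 * A minimal C++ sketch of the callback form (hypothetical; `ret_sinfo` is the
 * struct info being erased and `param_shape_vars` is assumed to be a set of the
 * tir::Vars defined by the function signature):
 *
 * \code
 * arith::Analyzer ana;
 * StructInfo erased = EraseToWellDefined(
 *     ret_sinfo,
 *     [&](const tir::Var& v) -> Optional<PrimExpr> {
 *       if (param_shape_vars.count(v)) return v;  // keep vars defined by params
 *       return NullOpt;                           // erase everything else
 *     },
 *     nullptr, &ana);
 * \endcode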
+ */ +TVM_DLL StructInfo +EraseToWellDefined(const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const Var& var)> f_var_map = nullptr, + arith::Analyzer* ana = nullptr); + +/*! + * \brief EraseToWellDefined variant with map. + * \param info The struct info. + * \param shape_var_map map to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param var_map map to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana = nullptr); + +/*! + * \brief Fine grained result of base check. + * + * This analysis comes with different levels of checking failures + * that can help to customize the compilation decisions. + * + * For a given pair of lhs_struct_info, rhs_struct_info. We adopt + * the following terminology: + * - LSet = {value | value matches lhs_struct_info} + * - RSet = {value | value matches rhs_struct_info} + * + * See the definition of each level below. + */ +enum class BaseCheckResult { + /*! + * \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty + */ + kFailL0 = 0, + /*! + * \brief LSet is not superset of RSet by only looking at static information. + * + * \note This level will trigger static type checking error when lhs is param and rhs is arg. + */ + kFailL1 = 1, + /*! + * \brief WLSet is not superset of RSet because of mismatch in value information. + * + * L1-level mismatches in params of FuncStructInfo is categorized as + * If lhs is FuncStructInfo, then L1-level mismatch in its params + * is categorized as L2-level mismatch for lhs. + * + * Design considerations for functions: + * - (a) We want to be able to erase type/value in function signature + * when we unify function struct info and preserve simpler representations. + * - (b) We automatically insert match_cast at function boundary, so + * we can erase (int)->int argument as (object)->int. + * The input shape/type mismatch will be detected by runtime checks at function boundary. + * This behavior is also consistent with the PackedFunc behavior. + * + * \note This level means there is no problem about static known information. + * It is OK for the checker to do best effort and return this value. + */ + kFailL2 = 2, + /*! \brief LSet is superset of RSet. */ + kPass = 3 +}; + +/*! + * \brief Run a base check to see if base subsumes derived. + * + * This function returns fine-grained base-check result on reasons of failure. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + * + * \sa BaseCheckResult + */ +TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Check the relation of two struct info to see if one subsumes another one. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. 
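 *
 * A hypothetical sketch of how a caller might combine this check with
 * StructInfoBaseCheck above (`param_sinfo` and `arg_sinfo` are placeholder
 * struct infos):
 *
 * \code
 * arith::Analyzer ana;
 * if (!IsBaseOf(param_sinfo, arg_sinfo, &ana)) {
 *   BaseCheckResult res = StructInfoBaseCheck(param_sinfo, arg_sinfo, &ana);
 *   // kFailL0 / kFailL1: statically impossible or unprovable from static info.
 *   // kFailL2: only value-level mismatch; a runtime match_cast could still pass.
 * }
 * \endcode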
+ */ +TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Unify the two struct info to their least common ancestor. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The unified information. + */ +TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, + arith::Analyzer* ana = nullptr); + +//----------------------------------- +// General IR analysis +//----------------------------------- +/*! + * \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array BoundVars(const Expr& expr); + +/*! + * \brief Get free type parameters from expression expr. + * + * Free variables are variables that are not bound by a + * varbinding or a function parameter in the context. + * + * \param expr the expression. + * + * \return List of free vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array FreeVars(const Expr& expr); + +/*! + * \brief Get all variables from expression expr. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllVars(const Expr& expr); + +/*! + * \brief Get all global variables from expression expr. + * + * AllVars is a superset of BoundVars and FreeVars. + * The union of BoundVars and FreeVars is Allvars. + * + * \param expr the expression. + * + * \return List of all global variables, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); + +/*! + * \brief Find all sets of recursive or mutually recursive functions in the module. + * + * Two or more functions are mutually recursive if there is some cycle of references + * among them. For example, if there are two functions A and B, they are + * mutually recursive if A calls B and B calls A. Another case would be with + * three functions A, B, and C, where A calls B, B calls C, and C calls A. + * + * (Note that functions do not have to call each other to reference each other. + * For example, if a function returns another function, that is still a reference + * that could potentially be recursive, even without a call.) + * + * If a function is simply recursive and not mutually recursive with any other, + * it will be reported as a group by itself. + * + * \param m The module + * + * \return List of all groups of mutually recursive functions. + * Each member of the result is a list of functions in the module + * that are all mutually recursive. + * If a function is simply recursive and not mutually recursive with any other, + * then it will be listed as a group by itself. + */ +TVM_DLL tvm::Array> DetectRecursion(const IRModule& m); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param m The IRModule to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const IRModule& m); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param expr The expression to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const Expr& expr); + +/*! + * \brief Analyze var -> value mapping from VarBindings. 
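 *
 * A hypothetical sketch (`dfb` is a relax::DataflowBlock taken from an existing
 * function):
 *
 * \code
 * Map<Var, Expr> var2val = AnalyzeVar2Value(dfb);
 * for (const auto& kv : var2val) {
 *   // kv.first: a Var bound inside the block; kv.second: the Expr it binds.
 * }
 * \endcode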
+ * + * \param dfb The dataflow block to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); + +/*! + * \brief Return a mapping from variable name to its Bindings. + * + * \param fn The function to be analyzed. + * \return A mapping from variable name to its Bindings. + */ +TVM_DLL Map> NameToBinding(const Function& fn); + +/*! + * \brief Get the use-def chain of variables inside a dataflow block. + * + * \param dfb The dataflow block to be analyzed. + * \return A map mapping variable definitions to a set of uses. + */ +TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); + +/*! + * \brief Get the use-def chain of variables inside a function. + * + * \param fn The function to be analyzed. + * \return A map from variable definitions to a set of uses and variables needed by return value. + */ +std::pair>, Array> FunctionUseDef(const Function& fn); + +/*! + * \brief Remove unused statements inside DataflowBlocks. + * + * \param fn The function to remove unused statements. + * \return The function that contains no unused statements in DataflowBlock. + */ +TVM_DLL Function RemoveAllUnused(const Function fn); + +/*! + * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. + * + * \param func The PrimFunc to be analyzed. + * \return The Op Pattern Kind. + * + * \note This analysis applies on TIR function but is primarily used by relax passes. + * As a result we place it under the relax namespace. + */ +TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); + +/*! + * \brief Check if the given PrimFunc is essentially doing a reshape operation. + * The reshape operation also includes expand_dims, squeeze, flatten, etc. + * \details Here the allowed reshape pattern is: for example, assume the operation is + * `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove that the flattened + * index of l_0, ..., l_b under buffer B equals to the flattened index of r_0, ..., r_a under + * buffer A. + * \param func The function to be examined. + * \return A boolean indicating if the given PrimFunc is doing a reshape. + * \note According to the description above, the returned result can only be false-negative and + * cannot be false-positive, since whenever we cannot prove the equality, we return false. This + * property guarantees the safety of this function. + */ +TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); + +/*! + * \brief Check if the IRModule is well formed. + * + * \param m the IRModule to check. + * \param check_struct_info A boolean flag indicating if the property "every Expr + * must have defined structure info" will be checked. + * \return true if the IRModule is well formed, false if not. + * \note By default the structure info is always checked. It is only in test cases + * where `check_struct_info` might be false, so that other well-formed requirements + * will be well tested and will not be blocked by not having structure info. + */ +TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); + +/*! + * \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks + * and buffers for the PrimFunc. + * + * \param fn The PrimFunc to be analyzed. + * \param write_buffer_transformations Array of IndexMap transformations on PrimFunc outputs. + * \return Suggested transforms per block in `fn`. For each block the returned value is a map + * from the object (block or buffer) to it's index map transformation. 
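 *
 * A hypothetical sketch (`prim_func` is a tir::PrimFunc and `out_transform` an
 * already-built tir::IndexMap describing the desired layout of its single
 * output buffer):
 *
 * \code
 * auto suggested = SuggestLayoutTransforms(prim_func, {out_transform});
 * for (const auto& kv : suggested) {
 *   // kv.first is a block of `prim_func`; kv.second maps that block and the
 *   // buffers it touches to the index maps that should be applied to them.
 * }
 * \endcode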
+ */ + +TVM_DLL Map> SuggestLayoutTransforms( + const Function& fn, Array write_buffer_transformations); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ANALYSIS_H_ diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h new file mode 100644 index 000000000000..6af176a42c9d --- /dev/null +++ b/include/tvm/relax/attrs/create.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/create.h + * \brief Attributes for tensor creation operators. + */ +#ifndef TVM_RELAX_ATTRS_CREATE_H_ +#define TVM_RELAX_ATTRS_CREATE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ +struct InitAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") { + TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor."); + } +}; // struct InitAttrs + +/*! \brief Attributes used in tril and triu operator */ +struct TriluAttrs : public tvm::AttrsNode { + int k; + + TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") { + TVM_ATTR_FIELD(k).describe( + "The number of diagonals above or below the main diagonal to exclude or include."); + } +}; // struct TriluAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_CREATE_H_ diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h new file mode 100644 index 000000000000..c5a5a4e7d22e --- /dev/null +++ b/include/tvm/relax/attrs/datatype.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/datatype.h + * \brief Attributes for datatype operators. + */ +#ifndef TVM_RELAX_ATTRS_DATATYPE_H_ +#define TVM_RELAX_ATTRS_DATATYPE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! 
\brief Attributes used in astype operator */ +struct AstypeAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(AstypeAttrs, "relax.attrs.AstypeAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type"); + } +}; // struct AstypeAttrs. + +/*! \brief Attributes used in wrap_param operator */ +struct WrapParamAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(WrapParamAttrs, "relax.attrs.WrapParamAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type"); + } +}; // struct WrapParamAttrs. + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_DATATYPE_H_ diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h new file mode 100644 index 000000000000..13463aaa4849 --- /dev/null +++ b/include/tvm/relax/attrs/image.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/image.h + * \brief Attributes for image operators. + */ +#ifndef TVM_RELAX_ATTRS_IMAGE_H_ +#define TVM_RELAX_ATTRS_IMAGE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array roi; + String layout; + String method; + String coordinate_transformation_mode; + String rounding_method; + double cubic_alpha; + int cubic_exclude; + double extrapolation_value; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") { + TVM_ATTR_FIELD(roi).describe( + "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'"); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .describe("Value to return when roi is outside of the image"); + TVM_ATTR_FIELD(out_dtype).describe( + "The dtype of the output tensor. It it is not specified, the output will have the same " + "dtype as input if not specified."); + } +}; // struct Resize2dAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_IMAGE_H_ diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h new file mode 100644 index 000000000000..c95395a80376 --- /dev/null +++ b/include/tvm/relax/attrs/index.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/index.h + * \brief Attributes for indexing operators. + */ +#ifndef TVM_RELAX_ATTRS_INDEX_H_ +#define TVM_RELAX_ATTRS_INDEX_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in take operator */ +struct TakeAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis over which to select values."); + } +}; // struct TakeAttrs + +/*! \brief Attributes used in strided_slice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array axes; + Array begin; + Array end; + Optional> strides; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied."); + TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive."); + TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive."); + TVM_ATTR_FIELD(strides).describe( + "Specifies the stride values, it can be negative in that case, the input tensor will be " + "reversed in that particular axis. 
If not specified, it by default is an list of ones of " + "the same length as `axes`."); + } +}; // struct StridedSliceAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_INDEX_H_ diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h new file mode 100644 index 000000000000..4b0e04298c9e --- /dev/null +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/linear_algebra.h + * \brief Attributes for linear algebra operators. + */ +#ifndef TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for matmul operator */ +struct MatmulAttrs : public tvm::AttrsNode { + DataType out_dtype; + + TVM_DECLARE_ATTRS(MatmulAttrs, "relax.attrs.MatmulAttrs") { + TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor"); + } +}; // struct MatmulAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h new file mode 100644 index 000000000000..4aa51f2b73d4 --- /dev/null +++ b/include/tvm/relax/attrs/manipulate.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/manipulate.h + * \brief Attributes for tensor manipulation operators. + */ +#ifndef TVM_RELAX_ATTRS_MANIPULATE_H_ +#define TVM_RELAX_ATTRS_MANIPULATE_H_ + +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in concat operators */ +struct ConcatAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axis at which the input arrays are concatenated." + "Should lie in range `[-ndim, ndim)`."); + } +}; // struct ConcatAttrs + +/*! 
\brief Attributes used in expand_dims operators */ +struct ExpandDimsAttrs : public tvm::AttrsNode { + Array axis; + + TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axes at which the input array are expanded. " + "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, " + "with the convention of negative indexing."); + } +}; // struct ExpandDimsAttrs + +/*! \brief Attributes used in layout_transform operator */ +struct LayoutTransformAttrs : public tvm::AttrsNode { + tir::IndexMap index_map; + // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This + // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. + Optional pad_value; + + TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") { + TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply."); + TVM_ATTR_FIELD(pad_value).describe( + "The specific value to be used to pad if the layout transform would result in implicit " + "padding. If not specified, the compiler is free to choose any value."); + } +}; // struct LayoutTransformAttrs + +/*! \brief Attributes used in permute_dims operator */ +struct PermuteDimsAttrs : public tvm::AttrsNode { + Optional> axes; + + TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs") { + TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified."); + } +}; // struct PermuteDimsAttrs + +/*! \brief Attributes used in split operator */ +struct SplitAttrs : public tvm::AttrsNode { + ObjectRef indices_or_sections; + int axis; + + TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") { + TVM_ATTR_FIELD(indices_or_sections) + .describe("The input array of indices or the number of split sections."); + TVM_ATTR_FIELD(axis).describe("The axis to be splitted"); + } +}; // struct SplitAttrs + +/*! \brief Attributes used in squeeze operators */ +struct SqueezeAttrs : public tvm::AttrsNode { + Optional> axis; + + TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axis does not has dimension 1."); + } +}; // struct SqueezeAttrs + +/*! \brief Attributes used in repeat operators */ +struct RepeatAttrs : public tvm::AttrsNode { + int repeats; + Optional axis; + + TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs") { + TVM_ATTR_FIELD(repeats).describe("The number of repetitions."); + TVM_ATTR_FIELD(axis).describe( + "The axis along which to repeat values. The negative numbers are interpreted " + "counting from the backward. By default, use the flattened input array, and " + "return a flat output array."); + } +}; // struct RepeatAttrs + +/*! \brief Attributes used in tile operators */ +struct TileAttrs : public tvm::AttrsNode { + Array repeats; + + TVM_DECLARE_ATTRS(TileAttrs, "relax.attrs.TileAttrs") { + TVM_ATTR_FIELD(repeats).describe("The number of repetitions of data along each axis."); + } +}; // struct TileAttrs + +/*! \brief Attributes used in cumsum operators */ +struct CumsumAttrs : public tvm::AttrsNode { + Optional axis; + DataType dtype; + + TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") { + TVM_ATTR_FIELD(axis).describe( + "Axis along which the cumulative sum is computed." 
+ "The default (None) is to compute the cumsum over the flattened array."); + TVM_ATTR_FIELD(dtype).describe( + "Type of the returned array and of the accumulator in which the elements are summed." + "If dtype is not specified, it defaults to the dtype of data."); + } +}; // struct CumsumAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_MANIPULATE_H_ diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h new file mode 100644 index 000000000000..bcfe3207bcef --- /dev/null +++ b/include/tvm/relax/attrs/nn.h @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/nn.h + * \brief Attributes for neural network operators. + */ +#ifndef TVM_RELAX_ATTRS_NN_H_ +#define TVM_RELAX_ATTRS_NN_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in Conv1d operator */ +struct Conv1DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv1DAttrs, "relax.attrs.Conv1DAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on both sides" + "two int : padding width in the order of (left, right)"); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, width" + "dimensions respectively. Convolution is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIW', 'IOW', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv1dAttrs + +/*! 
\brief Attributes used in Conv2d operator */ +struct Conv2DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv2DAttrs, "relax.attrs.Conv2DAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv2dAttrs + +/*! \brief Attributes used in Conv2d operator */ +struct Conv2DTransposeAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array output_padding; + Array dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relax.attrs.Conv2DTransposeAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(output_padding).describe("Used to disambiguate the output shape."); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." 
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv2DTransposeAttrs + +/*! \brief Attributes used in max_pool2d and avg_pool2d operator */ +struct Pool2DAttrs : public tvm::AttrsNode { + Array pool_size; + Array strides; + Array padding; + Array dilation; + bool ceil_mode; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(Pool2DAttrs, "relax.attrs.Pool2DAttrs") { + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(ceil_mode).describe( + "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, " + "every element in the input tensor will be covered by a sliding window."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct Pool2dAttrs + +/*! \brief Attributes for 2d adaptive pool operator */ +struct AdaptivePool2DAttrs : public tvm::AttrsNode { + Optional> output_size; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") { + TVM_ATTR_FIELD(output_size).describe("Output height and width."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct AdaptivePool2DAttrs + +/*! \brief Attributes used in softmax operators */ +struct SoftmaxAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(SoftmaxAttrs, "relax.attrs.SoftmaxAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis to sum over when computing softmax."); + } +}; + +/*! 
\brief Attributes used in batch_norm operator */ +struct BatchNormAttrs : public tvm::AttrsNode { + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct BatchNormAttrs + +/*! \brief Attributes used in layer_norm operator */ +struct LayerNormAttrs : public tvm::AttrsNode { + Array axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(LayerNormAttrs, "relax.attrs.LayerNormAttrs") { + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct LayerNormAttrs + +/*! \brief Attributes used in group_norm operator */ +struct GroupNormAttrs : public tvm::AttrsNode { + int num_groups; + int channel_axis; + Array axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") { + TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into."); + TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel."); + TVM_ATTR_FIELD(axes).describe( + "The axes that along which the normalization is applied (excluding the channel axis)."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct GroupNormAttrs + +/*! \brief Attributes used in dropout operator */ +struct DropoutAttrs : public tvm::AttrsNode { + double rate; + + TVM_DECLARE_ATTRS(DropoutAttrs, "relax.attrs.DropoutAttrs") { + TVM_ATTR_FIELD(rate).describe( + "Fraction of the input that gets dropped out during training time"); + } +}; // struct DropoutAttrs + +/*! \brief Attributes used in dropout operator */ +struct AttentionAttrs : public tvm::AttrsNode { + Optional scale; + + TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") { + TVM_ATTR_FIELD(scale).describe( + "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim)."); + } +}; // struct AttentionAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_NN_H_ diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h new file mode 100644 index 000000000000..f3854078f11e --- /dev/null +++ b/include/tvm/relax/attrs/search.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/search.h + * \brief Attributes for search operators. + */ +#ifndef TVM_RELAX_ATTRS_SEARCH_H_ +#define TVM_RELAX_ATTRS_SEARCH_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for search operators */ +struct ArgmaxArgminAttrs : public tvm::AttrsNode { + Optional axis; + bool keepdims; + + TVM_DECLARE_ATTRS(ArgmaxArgminAttrs, "relax.attrs.ArgmaxArgminAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis along which to perform the argmin/argmax."); + TVM_ATTR_FIELD(keepdims).describe( + "If this is set to `True`, the reduced axis is left in the result as dimension with size " + "one."); + } +}; // struct ArgmaxArgminAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_SEARCH_H_ diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h new file mode 100644 index 000000000000..bb1ab2195d9a --- /dev/null +++ b/include/tvm/relax/attrs/statistical.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/statistical.h + * \brief Attributes for statistical operators. + */ +#ifndef TVM_RELAX_ATTRS_STATISTICAL_H_ +#define TVM_RELAX_ATTRS_STATISTICAL_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for statistical operators */ +struct StatisticalAttrs : public tvm::AttrsNode { + Optional> axis; + bool keepdims; + + TVM_DECLARE_ATTRS(StatisticalAttrs, "relax.attrs.StatisticalAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis or axes along which to perform the reduction."); + TVM_ATTR_FIELD(keepdims).describe( + "If this is set to `True`, the reduced axes are left in the result as dimension with size " + "one."); + } +}; // struct StatisticalAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_STATISTICAL_H_ diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h new file mode 100644 index 000000000000..2fb11f5a6f83 --- /dev/null +++ b/include/tvm/relax/backend.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend.h + * \brief Relax backend specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_H_ +#define TVM_RELAX_BACKEND_H_ + +#include + +namespace tvm { +namespace relax { +namespace transform { + +/*! + * \brief Perform builtin lowering to map most of the op to VM builtin functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMBuiltinLower(); + +/*! + * \brief Lower the shape expression in relax to VM shape heap and TIR functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMShapeLower(); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_H_ diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h new file mode 100644 index 000000000000..3f3d4d047dc6 --- /dev/null +++ b/include/tvm/relax/binding_rewrite.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/binding_rewrite.h + * \brief An IR rewriter to easily add/remove/replace bindings (statements). + */ + +#ifndef TVM_RELAX_BINDING_REWRITE_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Statement rewriter for relax.DataflowBlock. */ +class DataflowBlockRewriteNode : public Object { + public: + /*! \brief Replace all uses of old_var with new_var. */ + void ReplaceAllUses(Var old_var, Var new_var); + /*! \brief Insert a Binding statement. */ + void Add(Binding binding); + /*! \brief Insert an expression as VarBinding with variable name. */ + void Add(String var_name, Expr expr, bool is_dfvar = false) { + auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr)) // + : Var(var_name, GetStructInfo(expr)); + Add(VarBinding(std::move(var), std::move(expr))); + } + /*! \brief Insert an expression as VarBinding with automatic variable name. */ + void Add(Expr expr, bool is_dfvar = false) { + Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar); + } + /*! \brief Remove the definition statement of an unused variable. */ + void RemoveUnused(Var unused, bool allow_undef = false); + /*! \brief Remove the definition statements of all unused variables. 
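 *
 * A hypothetical end-to-end sketch of this rewriter, using the constructor and
 * MutateIRModule declared further below (`dfb`, `fn`, `mod`, `old_var` and
 * `new_var` are placeholders):
 *
 * \code
 * DataflowBlockRewrite rewrite(dfb, fn);
 * rewrite->ReplaceAllUses(old_var, new_var);  // rewire users of old_var
 * rewrite->RemoveAllUnused();                 // drop now-dead bindings
 * IRModule new_mod = rewrite->MutateIRModule(mod);
 * \endcode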
*/ + void RemoveAllUnused(); + + /*! \brief The rewritten dataflow block. */ + DataflowBlock MutatedDataflowBlock() { return dfb_; } + /*! \brief The rewritten function. */ + Function MutatedFunc() { return root_fn_.value(); } + /*! \brief The rewritten IRModule. */ + IRModule MutateIRModule(IRModule irmod); + + /*! \brief Visit attributes. */ + void VisitAttrs(AttrVisitor* v) { + v->Visit("dfb", &dfb_); + v->Visit("root_fn", &root_fn_); + } + + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + + protected: + friend class DataflowBlockRewrite; + + DataflowBlock dfb_; //!< The rewritten dataflow block. + Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + Map> to_users_; //!< Map from variable to its users. + Array fn_outputs_; //!< Variables required by function outputs. + + private: + NameTable name_table_; //!< Name table for tracking and generating unique names. +}; + +/*! + * \brief A statement rewriter for relax.DataflowBlock. + * \sa DataflowBlockRewriteNode + */ +class DataflowBlockRewrite : public ObjectRef { + public: + TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + DataflowBlockRewriteNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); +}; + +} // namespace relax +} // namespace tvm + +#define TVM_RELAX_BINDING_REWRITE_H_ +#endif // TVM_RELAX_BINDING_REWRITE_H_ diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h new file mode 100644 index 000000000000..7222ae08f956 --- /dev/null +++ b/include/tvm/relax/block_builder.h @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/block_builder.h + * \brief The utility for constructing Relax binding blocks. + */ +#ifndef TVM_RELAX_BLOCK_BUILDER_H_ +#define TVM_RELAX_BLOCK_BUILDER_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A builder to build Relax binding blocks. + * + * BlockBuilder provides the following three categories + * of main functionalities for IR building and transformations: + * + * - Global context management: manages the IRModule, + * allowing query, update the surrounding global context. + * Provide context tools for analysis. + * - Scope management: + * - Manages block scopes for bulding nested blocks. + * - Emit bindings to the current scope. + * - Construct blocks by calling EndScope. 
+ * - Normalization: Take an Expr, normalize it + * to deduce shape/type, turn things into normal forms. + * + * Importantly, these three categories of features can be dependent + * on each other. For example, when we emit into scope we will call + * normalize to ensure the code is in normal form. Similarly, when we + * normalize we could choose to emit into the current context. + * + * We would encourage the developers to keep these three category + * in mind when using and developing BlockBuilder, we can group + * the code in a logically clean way. + * + * BlockBuilderNode is implemented as a virtual interface to + * allow logically grouped implementation and internal data + * structures that are hidden from the users. + */ +class BlockBuilderNode : public Object { + public: + //------------------------------- + // Global Context management + //------------------------------- + /*! + * \brief Get the name table for generating unique names. + * + * \return The name table. + */ + virtual NameTable* name_table() = 0; + + /*! + * \brief Get the context IRModule in this builder. + * + * \note The context + * \return The IRModule in this BlockBuilder. + */ + virtual IRModule GetContextIRModule() const = 0; + + /*! + * \brief Add a Relax function or a TIR PrimFunc to internal context module. + * \param func The function to be added. + * \param func_name_hint The name hint of the function to be added. + * \note If the function to be added already exists, return its + * GlobalVar directly. + * \return The global var bound to the added function. + */ + virtual GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) = 0; + + /*! + * \brief Update a Relax function or a TIR PrimFunc in the internal context module. + * \param gv The global var referring the function to be updated. + * \param function The updated function. + */ + virtual void UpdateFunction(const GlobalVar& gv, BaseFunc function) = 0; + + /*! + * \brief Report an error during transformation construction. + * \param diagnostic The diagnostic information. + */ + virtual void ReportFatal(const Diagnostic& diagnostic) = 0; + + //------------------------------- + // Scope management + //------------------------------- + /*! + * \brief Lookup the binding value that var binds to in the current emitted sequences. + * \param var The input var. + * \return The Expr bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + virtual Optional LookupBinding(const Var& var) = 0; + + /*! + * \brief Begin a new scope, with optional parameters that + * are visible within the scope. + * + * \param params Parameters that are visible within the scope. + * + * \note This function should be called when new scope is introduced + * (function, seq) to properly track the variable availability + * and help the best effort deduction. + * + * \sa EndScope + */ + virtual void BeginScope(Optional> params) = 0; + + /*! \brief End the previously defined scope. */ + virtual void EndScope() = 0; + + /*! \brief Begin to build a DataflowBlock. */ + virtual void BeginDataflowBlock() = 0; + + /*! \brief Begin to build a BindingBlock. */ + virtual void BeginBindingBlock() = 0; + /*! + * \brief End building a BindingBlock. + * \return The BindingBlock being built. + */ + virtual BindingBlock EndBlock() = 0; + + /*! + * \brief Check if the block being built is DataflowBlock or not. + * \return A boolean that indicates if the block being built is DataflowBlock or not. + */ + virtual bool CurrentBlockIsDataFlow() = 0; + + /*! 
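 * A hypothetical sketch of the scope/emit workflow described above, using the
 * Emit/EmitOutput/EndBlock methods declared below (`call_expr` is a placeholder
 * relax expression):
 *
 * \code
 * BlockBuilder bb = BlockBuilder::Create(NullOpt);
 * bb->BeginDataflowBlock();
 * Var lv0 = bb->Emit(call_expr);   // normalized and bound to a fresh DataflowVar
 * Var gv0 = bb->EmitOutput(lv0);   // re-bound to a Var visible outside the block
 * BindingBlock block = bb->EndBlock();
 * \endcode
 *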
+ * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The new variable that \p expr is bound to. + * + * \note This Emit function normalizes the \p expr, and + * performs shape and type deductions by calling Normalize. + */ + virtual Var Emit(Expr expr, String name_hint = "") = 0; + + /*! + * \brief Emit a MatchCast. + * \param value The input value. + * \param struct_info The struct info to be matched. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to the MatchCast. + */ + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; + + /*! + * \brief Generate an output for the current dataflow block. + * \param output The output variable of the block. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to \p output. + */ + virtual Var EmitOutput(Expr output, String name_hint = "") = 0; + + /*! + * \brief Emit a binding that is already normalized. + * + * \param normalized_binding A binding whose value is already normalized. + * + * \note This function requires binding to be pre-normalized. + */ + virtual void EmitNormalized(Binding normalized_binding) = 0; + + /*! + * \brief Convert an expression to normal form, and try to eagerly infer types and shapes. + * \param expr The input expression. + * \return The normalized expression. + * + * \note Invariant: If any of the sub expr have struct_info field. + * they must have already been normalized. + */ + virtual Expr Normalize(const Expr& expr) = 0; + + /*! + * \brief Normalize argument to a call or another IRNode. + * \param expr The input expression. + * \return The normalized expression. + * + * \note This function will create a binding var for non-leaf expressions such as Call. + */ + virtual Expr NormalizeArgument(const Expr& expr) = 0; + + /*! + * \brief Get the analyzer of the BlockBuilder. + * \return The BlockBuilder's arithmetic analyzer. + */ + virtual arith::Analyzer* GetAnalyzer() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.BlockBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); +}; + +class BlockBuilder : public ObjectRef { + public: + /*! + * \brief Create a BlockBuilder. + * + * \param ctx_mod Optional before-transformation context module for rewriting. + * \return The created BlockBuilder. + * + * \note When rewriting an existing IRModule, it is important to pass it in as + * ctx_mod so you can lookup the context functions for cross function + * call analysis. + */ + TVM_DLL static BlockBuilder Create(Optional ctx_mod); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BLOCK_BUILDER_H_ diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h new file mode 100644 index 000000000000..e4268be882d7 --- /dev/null +++ b/include/tvm/relax/dataflow_matcher.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
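To make the Emit/EmitOutput/EndBlock flow of the BlockBuilder interface above concrete, here is a minimal sketch that builds a single dataflow block. It assumes two already-constructed relax::Var inputs and that an elementwise operator such as "relax.add" is registered; it illustrates the interface and is not part of the header.

    #include <tvm/ir/op.h>
    #include <tvm/relax/block_builder.h>
    #include <tvm/relax/expr.h>

    namespace tvm {
    namespace relax {

    // Build `lv0 = relax.add(x, y); gv0 = lv0` inside one DataflowBlock.
    BindingBlock BuildAddBlock(Var x, Var y) {
      BlockBuilder bb = BlockBuilder::Create(NullOpt);  // no context IRModule needed here
      bb->BeginDataflowBlock();
      // Emit normalizes the call and deduces its struct info before binding it.
      Var lv0 = bb->Emit(Call(Op::Get("relax.add"), {x, y}), "lv0");
      // Dataflow vars are only visible inside the block; EmitOutput exposes the value.
      bb->EmitOutput(lv0, "gv0");
      return bb->EndBlock();
    }

    }  // namespace relax
    }  // namespace tvm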
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAX_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_DATAFLOW_MATCHER_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/** + * \brief Determine if a pattern matches an expression. + * \note The behavior of MatchExpr is to match a relax.Expr (`expr`) syntactically through + * one given pattern (`pattern`). + * + * \param pattern The pattern to match + * \param expr The expression to match + * \param bindings The mapping from relax.Var to relax.Expr + * \return true if matched + * \return false if unmatched + */ +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = NullOpt); + +/* \brief Similar to above, but return pairs of a matching pattern and an expression. */ +Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, Optional> bindings = NullOpt); + +/** + * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. + * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the + * starting point of the matching so that we can distinguish multiple matches. + * + * \param ctx The graph-wise patterns. + * \param dfb The function to match. + * \param start_hint The starting point expression to match to distinguish multiple matches. + * \param must_include_hint If start_hint is given, the return pattern must include start_hint. + * \return Matched patterns and corresponding bound variables + */ +TVM_DLL Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_DATAFLOW_MATCHER_H_ diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h new file mode 100644 index 000000000000..e4c27f3558ba --- /dev/null +++ b/include/tvm/relax/dataflow_pattern.h @@ -0,0 +1,830 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern.h + * \brief A pattern language for matching dataflow properties. 
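As a usage sketch of the expression-level matcher declared above: IsOp and Wildcard are the syntactic sugar declared further down in dataflow_pattern.h, and "relax.add" is assumed to be a registered operator.

    #include <tvm/relax/dataflow_matcher.h>
    #include <tvm/relax/dataflow_pattern.h>

    namespace tvm {
    namespace relax {

    bool IsAddOfAnything(const Expr& expr) {
      DFPattern pat = IsOp("relax.add")(Wildcard(), Wildcard());
      // Purely syntactic match; pass a Var -> Expr map as `bindings` if the
      // matcher should look through intermediate variable bindings.
      return MatchExpr(pat, expr);
    }

    }  // namespace relax
    }  // namespace tvm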
+ */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class PatternSeq; +class CallPattern; +class OrPattern; +class AndPattern; +class NotPattern; +class ShapePattern; +class TypePattern; +class DataTypePattern; +class AttrPattern; + +/*! + * \brief Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned. + * + * \param lhs Left hand side of the used-by relationship. + * \param rhs Right hand side of the used-by relationship. + * \param index lhs[-1] is used as the index'th argument of rhs[0]. + * \return PatternSeq The concatenated sequence of [*lhs, *rhs]. + */ +TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1); +/*! \brief Syntax sugar of UsedBy(lhs, rhs, -1). */ +TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs); + +/*! + * \brief Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned. + * + * \param lhs Left hand side of the used-by relationship. + * \param rhs Right hand side of the used-by relationship. + * \param index lhs[-1] is used as the index'th argument of rhs[0]. + * \return PatternSeq The concatenated sequence of [*lhs, *rhs]. + */ +TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1); +/*! \brief Syntax sugar of OnlyUsedBy(lhs, rhs, -1). */ +TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs); + +/*! + * \brief Base type of all dataflow patterns. + * \sa DFPattern + */ +class DFPatternNode : public Object { + public: + static constexpr const char* _type_key = "DFPatternNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); +}; + +/*! + * \brief Managed reference to dataflow patterns. + * \sa DFPatternNode + */ +class DFPattern : public ObjectRef { + public: + /*! \brief Syntatic Sugar for creating a CallPattern */ + template + CallPattern operator()(Args&&... args) const; + /*! \brief Syntatic Sugar for creating a CallPattern */ + TVM_DLL CallPattern operator()(const std::vector& args) const; + /*! \brief Syntatic Sugar for creating an OrPattern */ + TVM_DLL OrPattern operator|(const DFPattern& other) const; + /*! \brief Syntatic Sugar for creating an AndPattern */ + TVM_DLL AndPattern operator&(const DFPattern& other) const; + /*! \brief Syntatic Sugar for creating a NotPattern */ + TVM_DLL NotPattern operator~() const; + /*! \brief Syntatic Sugar for creating an AttrPattern */ + TVM_DLL AttrPattern HasAttr(const Map& attrs) const; + /*! \brief Syntatic Sugar for creating a TypePattern */ + TVM_DLL TypePattern HasType(const Type& type) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ + TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ + TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; + /*! \brief Syntatic Sugar for creating a ShapePattern */ + TVM_DLL ShapePattern HasShape(const Array& shape) const; + /*! \brief Syntatic Sugar for duplicating the current pattern */ + TVM_DLL DFPattern dup() const; + + /*! \brief Implicit conversion from DFPattern to PatternSeq */ + TVM_DLL operator PatternSeq() const; + + TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); +}; + +/*! 
\brief Constraint of a DFPattern edge (producer -> consumer) in graph-level matching */ +struct PairCons { + /*! \brief Constraint types of the edge */ + enum Type { + kUsedBy, /*!< producer ^ consumer */ + kOnlyUsedBy, /*!< producer >> consumer */ + } type = kUsedBy; + int index = -1; /*!< The argument index of the producer in the consumer caller site */ + + /*! + * \brief Construct a new PairCons object + * + * \param t The constraint type + * \param index The producer is called as the index'th argument of the consumer function. + */ + TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {} + + bool operator==(const PairCons& other) const { + return type == other.type && index == other.index; + } +}; + +/*! + * \brief A sequence of DFPatterns that the previous DFPattern is connected to the next one. + * \sa PatternSeq + */ +class PatternSeqNode final : public Object { + public: + tvm::Array patterns; /*!< The sequence of DFPatterns */ + std::vector pair_constraints; /*!< Constraints between the previous and next patterns */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); } + static constexpr const char* _type_key = "relax.dpl.PatternSeq"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object); +}; + +/*! + * \brief Managed reference to pattern sequences. + * \sa PatternSeqNode + */ +class PatternSeq final : public ObjectRef { + public: + TVM_DLL explicit PatternSeq(DFPattern init_pattern); + TVM_DLL explicit PatternSeq(tvm::Array patterns, bool only_used_by = false); + + PatternSeq UsedBy(PatternSeq other, int index = -1) const; + PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const; + + /*! \brief Syntatic Sugar for duplicating the current pattern sequence */ + PatternSeq dup() const; + + // friend functions + friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); + friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode); +}; + +/*! + * \brief A context to manage the graph-level pattern matching. + * \sa PatternContext + */ +class PatternContextNode : public Object { + public: + /*! \brief Constrainting matched graph with assertion to external uses */ + enum ExternUse { + kMay, /*!< No constraints */ + kMustNot, /*!< All nodes except outputs only have internal depedencies in the matched graph. */ + } allow_extern_use = kMay; + // src node -> constraints. + // Dst nodes are kept in a vector to keep them ordered. + std::map>>> constraints; + // Keep a separate vector of patterns to process constraints in a fixed order. + std::vector src_ordered; + + static constexpr const char* _type_key = "relax.dpl.PatternContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object); +}; + +/*! + * \brief Managed reference to a pattern context. + * \sa PatternContextNode + */ +class PatternContext : public ObjectRef { + public: + TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} + TVM_DLL explicit PatternContext(bool incremental = false); + + const PatternContextNode* operator->() const { + ICHECK(get() != nullptr); + return static_cast(get()); + } + + PatternContextNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + /*! + * \brief Build an edge constraint between two patterns (producer and consumer). + * + * \param producer The pattern corresponding to the producer node. + * \param consumer The pattern corresponding to the consumer node. 
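A short sketch of how the ^ / >> sugar chains patterns into a PatternSeq with the PairCons edge kinds described above; the operator names are illustrative only, and IsOp/Wildcard are declared near the end of this header.

    DFPattern conv = IsOp("relax.nn.conv2d")(Wildcard(), Wildcard());
    DFPattern relu = IsOp("relax.nn.relu")(Wildcard());

    // kUsedBy: conv's result is used by relu, possibly among other users.
    PatternSeq used_by = conv ^ relu;
    // kOnlyUsedBy: conv's result is used by relu and by nothing else.
    PatternSeq only_used_by = conv >> relu;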
+ * \param cons The constraint type. \sa PairCons + */ + void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) { + auto& pairs = (*this)->constraints[producer]; + auto it = std::find_if(pairs.begin(), pairs.end(), + [consumer](auto p) { return p.first == consumer; }); + if (it == pairs.end()) { + pairs.emplace_back(consumer, std::vector{cons}); + } else { + auto& vec = it->second; + ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) + << "Constraint already exists"; + vec.push_back(cons); + } + + auto& patterns = (*this)->src_ordered; + if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) { + patterns.push_back(producer); + } + } + + /*! \brief Get the constraint context object on the top of the stack */ + TVM_DLL static Optional Current(); + + class Internal; + + private: + /*! \brief The RAII-like entry of a constraint context scope */ + TVM_DLL void EnterWithScope(); + /*! \brief The RAII-like exit of a constraint context scope */ + TVM_DLL void ExitWithScope(); + friend class Internal; + friend class With; +}; + +/*! + * \brief Pattern for Relax Expression. + * \sa ExprPattern + */ +class ExprPatternNode : public DFPatternNode { + public: + Expr expr; /*!< The expression to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } + + static constexpr const char* _type_key = "relax.dpl.ExprPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to an ExprPattern. + * \sa ExprPatternNode + */ +class ExprPattern : public DFPattern { + public: + TVM_DLL explicit ExprPattern(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Variable. + * \note The name field matches any string if it is empty. + * \sa VarPattern + */ +class VarPatternNode : public DFPatternNode { + public: + String name; + const String& name_hint() const { return name; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "relax.dpl.VarPattern"; + TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a VarPattern. + * \sa VarPatternNode + */ +class VarPattern : public DFPattern { + public: + /*! + * \brief Create a pattern matching by variable name. + * + * \param name_hint Variable name to match. Any if empty (""). + */ + TVM_DLL VarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Dataflow Variable + * \sa DataflowVarPattern + */ +class DataflowVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a DataflowVarPattern. + * \sa DataflowVarPatternNode + */ +class DataflowVarPattern : public DFPattern { + public: + /*! \sa VarPattern::VarPattern */ + TVM_DLL DataflowVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Global Variable + * \sa GlobalVarPattern + */ +class GlobalVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode); +}; + +/*! 
+ * \brief Managed reference to a GlobalVarPattern. + * \sa GlobalVarPatternNode + */ +class GlobalVarPattern : public DFPattern { + public: + TVM_DLL GlobalVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Constant. + * \sa ConstantPattern + */ +class ConstantPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.ConstantPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a ConstantPattern. + * \sa ConstantPatternNode + */ +class ConstantPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); +}; + +/*! + * \brief A pattern to match a callable node in Relax. + * \sa CallPattern + */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \note The op field can be: + * - relay::Op which corresponds to the primitive operators. + * - user defined functions (Function, GlobalVar, Var). + */ + DFPattern op; /*!< The operator (function) being invoked */ + tvm::Array args; /*!< The arguments of the function call */ + /*! + * \note If varg_default_wildcard is true. Given args of [pA, pB], when matching a call whose + * arguments are [A, B, ...], the pattern will still match despite N(args) < N(call.args). That + * said, with varg_default_wildcard set to true, we match the args in the order we have, and + * regard the rest of the arguments as wildcards. + */ + bool varg_default_wildcard; /*!< N(args) can be < N(real args) by the padding of Wildcard */ + + // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + } + + static constexpr const char* _type_key = "relax.dpl.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DLL CallPattern(DFPattern op, Array args, bool varg_default_wildcard = false); + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! + * \brief A pattern to match an array of PrimExpr. + * \sa PrimArrPattern + * \note This is often used to match shapes specified as arguments to a function. + */ +class PrimArrPatternNode : public DFPatternNode { + public: + Array fields; /*!< The array to match */ + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + static constexpr const char* _type_key = "relax.dpl.PrimArrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a PrimArrPattern. + * \sa PrimArrPatternNode + */ +class PrimArrPattern : public DFPattern { + public: + TVM_DLL PrimArrPattern(Array arr); + TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode); +}; + +/*! + * \brief A pattern to match a Relax Function + * \sa Function + * \sa FunctionPattern + */ +class FunctionPatternNode : public DFPatternNode { + public: + tvm::Array params; /*!< The parameters of the function */ + /*! + * \note Note that in Relax, the function body is a SeqExpr which contains + * 1) SeqExprNode::blocks, which is a list of blocks of statements; and 2) + * SeqExprNode::body, which is an Expr that can be anything. FunctionPattern + * only matches the body of the function (writing patterns to statements is tricky). 
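A small sketch of the varg_default_wildcard behaviour documented on CallPattern above; "relax.concat" is only an illustrative operator name, and IsOp/Wildcard are the sugar declared near the end of this header.

    // Match a call to "relax.concat" whose first argument is anything, regardless
    // of how many further arguments the call site passes.
    DFPattern op = IsOp("relax.concat");
    CallPattern prefix_match(op, {Wildcard()}, /*varg_default_wildcard=*/true);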
+ */ + DFPattern body; /*!< The body of the function */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "relax.dpl.FunctionPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to FunctionPatternNode. + * \sa FunctionPatternNode + */ +class FunctionPattern : public DFPattern { + public: + /*! + * \brief Constructor + * \param params The parameters of the function. + * \param body The body of the function. + */ + TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); +}; + +/*! + * \brief Pattern to match a tuple of ordered expressions. + * \sa TuplePattern + */ +class TuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TuplePatternNode. + * \sa TuplePatternNode + */ +class TuplePattern : public DFPattern { + public: + TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! + * \brief A pattern to match multiple expressions unorderedly. + * \sa UnorderedTuplePattern + */ +class UnorderedTuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to UnorderedTuplePatternNode. + * \sa UnorderedTuplePatternNode + */ +class UnorderedTuplePattern : public DFPattern { + public: + TVM_DLL explicit UnorderedTuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode); +}; + +/*! + * \brief A pattern to match n'th indexing to a tuple. + * \sa TupleGetItem + * \sa TupleGetItemPattern + */ +class TupleGetItemPatternNode : public DFPatternNode { + public: + DFPattern tuple; /*!< The tuple Expression */ + int index; /*!< The index of the tuple with -1 meaning arbitrary */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple", &tuple); + v->Visit("index", &index); + } + + static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TupleGetItemPatternNode. + * \sa TupleGetItemPatternNode + */ +class TupleGetItemPattern : public DFPattern { + public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +/*! + * \brief Match a conjunction of other patterns. 
+ * \sa AndPattern + */ +class AndPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the conjunction */ + DFPattern right; /*!< The right hand side of the conjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.AndPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AndPatternNode. + * \sa AndPatternNode + */ +class AndPattern : public DFPattern { + public: + TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs); + TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode); +}; + +/*! + * \brief Match a disjunction of other patterns. + * \sa OrPattern + */ +class OrPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the disjunction */ + DFPattern right; /*!< The right hand side of the disjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.OrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to OrPatternNode. + * \sa OrPatternNode + */ +class OrPattern : public DFPattern { + public: + TVM_DLL OrPattern(DFPattern left, DFPattern right); + TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode); +}; + +/*! + * \brief Pattern for rejecting a certain pattern. + * \sa NotPattern + */ +class NotPatternNode : public DFPatternNode { + public: + DFPattern reject; /*!< The pattern to reject */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("reject", &reject); } + + static constexpr const char* _type_key = "relax.dpl.NotPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to NotPatternNode. + * \sa NotPatternNode + */ +class NotPattern : public DFPattern { + public: + TVM_DLL NotPattern(DFPattern reject); + TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode); +}; + +/*! + * \brief Wildcard Pattern is a pattern that can match anything. + * \sa WildcardPattern + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to WildcardPatternNode. + * \sa WildcardPatternNode + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +/*! + * \brief Pattern for matching a certain type. + * \sa TypePattern + */ +class TypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The pattern to match */ + Type type; /*!< The type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + static constexpr const char* _type_key = "relax.dpl.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TypePatternNode. + * \sa TypePatternNode + */ +class TypePattern : public DFPattern { + public: + TVM_DLL TypePattern(DFPattern pattern, Type type); + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain shape. 
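To illustrate the operator sugar on DFPattern together with the Or/And/Not/Wildcard nodes above (a sketch only; the operator names are illustrative):

    DFPattern add = IsOp("relax.add")(Wildcard(), Wildcard());
    DFPattern mul = IsOp("relax.multiply")(Wildcard(), Wildcard());

    DFPattern either  = add | mul;                 // OrPattern
    DFPattern both    = Wildcard() & add;          // AndPattern
    DFPattern not_add = ~add;                      // NotPattern
    DFPattern f32_add = add.HasDtype("float32");   // DataTypePattern refinement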
+ * \sa ShapePattern + */ +class ShapePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + Array shape; /*!< The shape to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "relax.dpl.ShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ShapePatternNode. + * \sa ShapePatternNode + */ +class ShapePattern : public DFPattern { + public: + TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain data type. + * \sa DataTypePattern + */ +class DataTypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DataType dtype; /*!< The data type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relax.dpl.DataTypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to DataTypePatternNode. + * \sa DataTypePatternNode + */ +class DataTypePattern : public DFPattern { + public: + TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has certain attributes. + * \sa AttrPattern + */ +class AttrPatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DictAttrs attrs; /*!< The attributes (a map/dictionary) to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + static constexpr const char* _type_key = "relax.dpl.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AttrPatternNode. + * \sa AttrPatternNode + */ +class AttrPattern : public DFPattern { + public: + TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + +/*! + * \brief A pattern of external function. + * \sa ExternFunc + * \sa ExternFuncPattern + */ +class ExternFuncPatternNode : public DFPatternNode { + public: + String global_symbol_; /*!< The global symbol name of the external function */ + + /*! \brief The the external function name */ + const String& global_symbol() const { return global_symbol_; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("global_symbol", &global_symbol_); } + + static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ExternFuncPatternNode. + * \sa ExternFuncPatternNode + */ +class ExternFuncPattern : public DFPattern { + public: + TVM_DLL ExternFuncPattern(String global_symbol); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); +}; + +/*! \brief Syntatic Sugar for creating a VarPattern with a name */ +VarPattern IsVar(const String& name); +/*! \brief Syntatic Sugar for creating a ConstantPattern */ +ConstantPattern IsConst(); +/*! \brief Syntatic Sugar for creating a WildcardPattern */ +WildcardPattern Wildcard(); +/*! 
\brief Syntatic Sugar for creating a ExprPattern */ +ExprPattern IsExpr(const Expr& expr); +/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ +ExprPattern IsOp(const String& op_name); +/*! \brief Syntatic Sugar for call_tir (return a tensor) */ +// Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +CallPattern IsCallTIR(const String& name, Optional args = NullOpt); +/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ +CallPattern IsCallTIR(const String& name, TuplePattern var_args); +/*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */ +CallPattern IsCallDPSPacked(const String& name, Optional args = NullOpt); +/*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */ +CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args); +/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ +DFPattern IsTuple(const Array& fields, bool unordered = false); +/*! \brief Syntatic Sugar for creating a TupleGetItemPattern */ +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1); + +/*! \brief Implementation of the templated CallPattern syntax sugar */ +template +CallPattern DFPattern::operator()(Args&&... args) const { + return CallPattern(GetRef(this->get()), + Array({std::forward(args)...})); +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_H_ diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h new file mode 100644 index 000000000000..983881ddc9a7 --- /dev/null +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern_functor.h + * \brief Functors and visitors for dataflow patterns. + */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signature + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAX_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! 
\brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + ICHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const OrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AndPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const NotPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + + virtual R VisitDFPattern_(const DataflowVarPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const GlobalVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExternFuncPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const PrimArrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const UnorderedTuplePatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. 
+ static FType InitVTable() { + FType vtable; + // Set dispatch + RELAX_DFPATTERN_FUNCTOR_DISPATCH(OrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(AndPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(NotPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + + RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around DFPatternFunctor. + * Recursively visit the content. + * + * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. + */ +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const OrPatternNode* op) override; + void VisitDFPattern_(const AndPatternNode* op) override; + void VisitDFPattern_(const NotPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const ConstantPatternNode* op) override; + void VisitDFPattern_(const DataTypePatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const FunctionPatternNode* op) override; + void VisitDFPattern_(const ShapePatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + + void VisitDFPattern_(const DataflowVarPatternNode* op) override; + void VisitDFPattern_(const GlobalVarPatternNode* op) override; + void VisitDFPattern_(const ExternFuncPatternNode* op) override; + void VisitDFPattern_(const PrimArrPatternNode* op) override; + void VisitDFPattern_(const UnorderedTuplePatternNode* op) override; + + protected: + // set of already-visited nodes + std::unordered_set visited_; +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h new file mode 100644 index 000000000000..03e58392c269 --- /dev/null +++ b/include/tvm/relax/exec_builder.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
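As a usage sketch of DFPatternVisitor defined above: a visitor that counts CallPattern nodes in a pattern tree, reusing the base class for recursion and visited-set bookkeeping.

    #include <tvm/relax/dataflow_pattern.h>
    #include <tvm/relax/dataflow_pattern_functor.h>

    namespace tvm {
    namespace relax {

    class CallPatternCounter : public DFPatternVisitor {
     public:
      int Count(const DFPattern& pattern) {
        visited_.clear();
        count_ = 0;
        VisitDFPattern(pattern);
        return count_;
      }

     private:
      void VisitDFPattern_(const CallPatternNode* op) override {
        ++count_;
        // Recurse into the callee and argument patterns via the base visitor.
        DFPatternVisitor::VisitDFPattern_(op);
      }
      int count_ = 0;
    };

    }  // namespace relax
    }  // namespace tvm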
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/exec_builder.h + */ +#ifndef TVM_RELAX_EXEC_BUILDER_H_ +#define TVM_RELAX_EXEC_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace vm = tvm::runtime::relax_vm; + +class ExecBuilder; + +/*! + * \brief A builder provides api to build VM executable with instructions. + */ +class ExecBuilderNode : public Object { + public: + /*! + * \brief Declare a function, it is OK to have multiple declarations. + * \param func The function name. + * \param kind The kind of the function. + */ + void DeclareFunction(const std::string& func, vm::VMFuncInfo::FuncKind kind); + /*! + * \brief To annotate the start of a vm function. + * \param func The function name. + * \param num_inputs The number of inputs. + * \param param_names The function parameter names. + * \param kind The kind of the function. + * \param init_register_size Initial setting of register file size. + */ + void EmitFunction(const std::string& func, int64_t num_inputs, + Optional> param_names, + vm::VMFuncInfo::FuncKind kind = vm::VMFuncInfo::FuncKind::kVMFunc, + int64_t init_register_size = 0); + /*! + * \brief Annotate the end of a vm function. + * \param func The function name. + */ + void EndFunction(const std::string& func); + /*! + * \brief Emit a call instruction for a packed function. + * \param func The packed function name. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(const std::string& func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a call instruction with func as argument. + * \param func The packed function index. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(vm::Instruction::Arg func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a ret instruction. + * \param result The return result. + * \note result must be a register. + */ + void EmitRet(vm::Instruction::Arg result); + /*! + * \brief Emit a goto instruction. + * \param pc_offset The program counter offset as the jump offset. + */ + void EmitGoto(vm::Index pc_offset); + /*! + * \brief Emit an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. + * \note result must be a register. + */ + void EmitIf(vm::Instruction::Arg cond, vm::Index false_offset); + /*! + * \brief Get function index by its name. + * \param name The name of the function. + * \return The argument corresponding to the function index. + */ + vm::Instruction::Arg GetFunction(const std::string& name); + /*! + * \brief Convert a constant value something that exec builder can understand. + * + * This function may update the constant pool to include the obj value. 
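A rough sketch of driving ExecBuilder end to end. It assumes a register-argument helper along the lines of vm::Instruction::Arg::Register(...) in the relax VM bytecode header and a packed function registered under the name "my.packed_add"; both are assumptions for illustration, and Create()/Get() are declared just below.

    #include <tvm/relax/exec_builder.h>

    namespace tvm {
    namespace relax {

    ObjectPtr<vm::Executable> BuildTrivialExec() {
      ExecBuilder eb = ExecBuilderNode::Create();
      // VM function "main" with one input (register 0) and two registers in total.
      eb->EmitFunction("main", /*num_inputs=*/1, Array<String>{"x"},
                       vm::VMFuncInfo::FuncKind::kVMFunc, /*init_register_size=*/2);
      // r1 = my.packed_add(r0, const 1)   -- Arg::Register is an assumed helper.
      eb->EmitCall("my.packed_add",
                   {vm::Instruction::Arg::Register(0), eb->ConvertConstant(1)},
                   /*ret=*/1);
      eb->EmitRet(vm::Instruction::Arg::Register(1));
      eb->EndFunction("main");
      return eb->Get();  // formalizes and validates the executable
    }

    }  // namespace relax
    }  // namespace tvm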
+ * + * \param value The input constant value + * \return An Arg that represents the result of constant argument. + */ + template + vm::Instruction::Arg ConvertConstant(T value) { + TVMRetValue rv; + rv = value; + return ConvertConstant_(rv); + } + /*! + * \brief Raw access to underlying executable build in progress. + */ + vm::Executable* exec() const; + /*! + * \brief Finalize the build, run formalize and get the final result. + * \note This function should not be called during construction. + */ + ObjectPtr Get(); + /*! + * \brief Create an ExecBuilder. + * \return The ExecBuilder. + */ + TVM_DLL static ExecBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.ExecBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); + + private: + /*! + * \brief Convert a constant value something that exec builder can understand. + * + * This function may update the constant pool to include the obj value. + * + * \param obj The constant value to be emitted + * \return An Arg that represents the result of constant argument. + */ + vm::Instruction::Arg ConvertConstant_(TVMRetValue obj); + + /*! + * \brief A helper function to check if an executable is legal by checking if registers are used + * properly + */ + void CheckExecutable(); + /*! + * \brief Formalize the executable. + */ + void Formalize(); + + /*! \brief The mutable internal executable. */ + ObjectPtr exec_; // mutable + /*! \brief internal dedup map when creating index for a new constant */ + std::unordered_map const_dedup_map_; +}; + +class ExecBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXEC_BUILDER_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h new file mode 100644 index 000000000000..0788193ee7c4 --- /dev/null +++ b/include/tvm/relax/expr.h @@ -0,0 +1,1039 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_EXPR_H_ +#define TVM_RELAX_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using Expr = RelayExpr; +using ExprNode = RelayExprNode; +/*! + * \brief The unique identifier of variables. + * + * Id is like name to the variables, + * except that id is unique for each Var. + * + * \note Do not create Id directly, they are created in Var. + */ +class IdNode : public Object { + public: + /*! + * \brief The name of the variable, + * this only acts as a hint to the user, + * and is not used for equality. 
+ */ + String name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } + + bool SEqualReduce(const IdNode* other, SEqualReducer equal) const { + return equal.FreeVarEqualImpl(this, other); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); } + + static constexpr const char* _type_key = "relax.Id"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); +}; + +class Id : public ObjectRef { + public: + /*! + * \brief The constructor + * \param name_hint The name of the variable. + */ + TVM_DLL explicit Id(String name_hint); + + TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); +}; + +/*! + * \brief Base type of all structure information. + * + * StructInfo stores possible structure information + * deduced during compile-time. It encapsulates + * both static type and runtime information such + * as shape. + * + * StructInfo of each non-primitive Expr can be + * deduced during compilation in a "best-effort" manner. + * + * When struct_info appears in function parameter and return + * signatures. They will imply a runtime check that matches + * the structure information with the value. + * + * When it appears in Expr, they follow "assume-semantics", + * which means the compiler will take the deduced information as it is + * and only do best effort prove and checks. + * + * Each struct info can be uniquely erased to a static-type. + * The compiler will still compile the code(with less information) + * when we erase to the static type. + * + * If an StructInfo contains an Expr field, then that field + * must be normalized already through NormalizeArg. + * This invariant will be checked in constructors + * and help us to simplify our assumption + * during struct info deduction. + */ +class StructInfoNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "StructInfo"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 5; + TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); +}; + +/*! + * \brief Managed reference to StructInfoNode. + * \sa StructInfoNode + */ +class StructInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode); +}; + +/*! + * \brief Call corresponds to callable invocation. + * Corresponds to operation in computational graph terminology. + */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be tvm::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The structure info arguments of a CallNode. + * sinfo_args is designed to be non-empty only for intrinsic op (e.g., + * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main + * usage of structure info inference. 
+ */ + Array sinfo_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("sinfo_args", &sinfo_args); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { + // skip sinfo_args check for primitive ops. + equal->MarkGraphNode(); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(op); + hash_reduce(args); + hash_reduce(attrs); + hash_reduce(sinfo_args); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Call"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); +}; + +class Call : public Expr { + public: + /*! + * \brief The constructor + * \param op The operator to be invoked. + * \param args The arguments of the call. + * \param attrs The attributes of the call node. + * \param sinfo_args The structure info arguments passed to a function. + * \param span The source span of the expression. + */ + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), + Array sinfo_args = Array(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); +}; + +/*! + * \brief Returns \p call with the given properties. A null property denotes 'no change'. + * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Call WithFields(Call call, Optional opt_op = Optional(), + Optional> opt_args = Optional>(), + Optional opt_attrs = Optional(), + Optional> opt_sinfo_args = Optional>(), + Optional opt_span = Optional()); + +/*! + * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * x = if (true) { 1 } else { 0 }; // x is 1 + * y = if (false) { 1 } else { 0 }; // y is 0 + * + * \note This is similar to C's ternary operator. + */ +class IfNode : public ExprNode { + public: + /*! \brief The condition. */ + Expr cond; + /*! \brief The expression evaluated when condition is true. */ + Expr true_branch; + /*! \brief The expression evaluated when condition is false */ + Expr false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(cond); + hash_reduce(true_branch); + hash_reduce(false_branch); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.If"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); +}; + +class If : public Expr { + public: + /*! + * \brief The constructor + * \param cond The condition of a if node. 
+ * \param true_branch The fall through branch + * \param false_branch The branch for execution when condition is false. + * \param span The source span of the expression. + */ + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); +}; + +/*! + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +If WithFields(If if_expr, Optional opt_cond = Optional(), + Optional opt_true_branch = Optional(), + Optional opt_false_branch = Optional(), + Optional opt_span = Optional()); + +/*! \brief Tuple container */ +class TupleNode : public ExprNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from fields. + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.expr.Tuple"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); +}; + +class Tuple : public Expr { + public: + /*! + * \brief The constructor + * \param fields The fields of a tuple. + * \param span The source span of the expression. + */ + TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); +}; + +/*! + * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. + * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), + Optional opt_span = Optional()); + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemNode : public ExprNode { + public: + /*! \brief The tuple Expression */ + Expr tuple; + /*! \brief which value to get */ + int index; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple_value", &tuple); + v->Visit("index", &index); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { + // struct info can be deterministically tuple and index. + return equal(tuple, other->tuple) && equal(index, other->index); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(tuple); + hash_reduce(index); + } + + static constexpr const char* _type_key = "relax.expr.TupleGetItem"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); +}; + +class TupleGetItem : public Expr { + public: + /*! + * \brief The constructor + * \param tuple The tuple to get an element from. + * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. + */ + TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); +}; + +/*! + * \brief Returns \p tuple_get_item with the given properties. 
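To make the constructors and WithFields helpers above concrete (a sketch; x, y and cond stand for already-built relax expressions):

    Tuple pair({x, y});                        // (x, y)
    TupleGetItem first(pair, /*index=*/0);     // pair[0]
    If branch(cond, /*true_branch=*/x, /*false_branch=*/y);

    // WithFields returns the input node unchanged when every optional field is
    // absent, and a copy with the new fields otherwise.
    Tuple swapped = WithFields(pair, Array<Expr>{y, x});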
A null property denotes 'no change'. + * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), + Optional opt_index = Optional(), + Optional opt_span = Optional()); + +/*! + * \brief Base type of all (non-function) leaf Exprs. + * \sa Expr + */ +class LeafExprNode : public ExprNode { + public: + static constexpr const char* _type_key = "relax.expr.LeafExpr"; + static constexpr const uint32_t _type_child_slots = 7; + TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa LeafExprNode + */ +class LeafExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode); +}; + +/*! \brief A shape expression which allows users to construct a shape containing PrimExpr. + */ +class ShapeExprNode : public LeafExprNode { + public: + /*! The values of the shape expression. */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from values. + return equal(values, other->values); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); } + + static constexpr const char* _type_key = "relax.expr.ShapeExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode); +}; + +class ShapeExpr : public LeafExpr { + public: + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); +}; + +/*! \brief The variable class for all Relax bindings. */ +class VarNode : public LeafExprNode { + public: + /*! \brief The identifier of the variable, which is used for comparing stable equality across + * transformations. */ + Id vid; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return vid->name_hint; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Var"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); +}; + +class Var : public LeafExpr { + public: + TVM_DLL explicit Var(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : Var(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); +}; + +/*! 
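[Editor's illustration, not part of the patch] A sketch of building a symbolic-shaped variable with ShapeExpr and Var from above. TensorStructInfo comes from struct_info.h, which is introduced later in this patch; its (shape, dtype) constructor is assumed here.

#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::relax;

Var MakeSymbolicInput() {
  tir::Var n("n", DataType::Int(64));
  ShapeExpr shape({n, IntImm(DataType::Int(64), 128)});  // the shape (n, 128)
  TensorStructInfo sinfo(shape, DataType::Float(32));    // assumed (shape, dtype) ctor
  return Var("x", sinfo);                                // annotated Relax variable
}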
\brief A sub-type of the variable node used to mark dataflow variables from + * normal visible "function local" bindings. + */ +class DataflowVarNode : public VarNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.DataflowVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); +}; + +class DataflowVar : public Var { + public: + TVM_DLL explicit DataflowVar(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : DataflowVar(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, Optional struct_info_annotation, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); +}; + +/*! + * \brief Constant tensor. + * + * \note Scalar constants are represented by ndim-0 constant tensors. + */ +class ConstantNode : public LeafExprNode { + public: + /*! \brief The data of the tensor */ + runtime::NDArray data; + + /*! \return The corresponding tensor type of the data */ + TensorType tensor_type() const; + + /*! \return Whether it is scalar(ndim-0 tensor) */ + bool is_scalar() const { return data->ndim == 0; } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("data", &data); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(data, other->data); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } + + static constexpr const char* _type_key = "relax.expr.Constant"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode); +}; + +class Constant : public LeafExpr { + public: + /*! + * \brief The constructor + * \param data The data of the constant tensor. + * \param span The source span of the expression. + */ + TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); +}; + +/*! + * \brief PrimValue. + * + * Expression representing a TIR POD expression. + */ +class PrimValueNode : public LeafExprNode { + public: + /*! \brief The prim expr representing the value */ + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. 
+ return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.PrimValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to PrimValueNode + * \sa PrimValeNode + */ +class PrimValue : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span()); + + /*! + * \brief Create a int64 prim value. + * \param value The input value. + * \param span The source span of the expression. + * \return The created prim value. + */ + TVM_DLL static PrimValue Int64(int64_t value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); +}; + +/*! + * \brief Represent a string literal constant. + */ +class StringImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + String value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.StringImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to StringImm + * \sa StringImmNode + */ +class StringImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit StringImm(String value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); +}; + +/*! + * \brief Represent a data type constant. + */ +class DataTypeImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + DataType value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.DataTypeImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to DataTypeImm + * \sa DataTypeImmNode + */ +class DataTypeImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); +}; + +/*! \brief The base class of a variable binding in Relax. */ +class BindingNode : public Object { + public: + /*! \brief The return variable to bound to. 
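[Editor's illustration, not part of the patch] The three scalar-like leaf expressions above are typically passed as call arguments to operators expecting POD values; a minimal sketch with placeholder values:

#include <tvm/relax/expr.h>

using namespace tvm;
using namespace tvm::relax;

void MakeImmediates() {
  PrimValue axis = PrimValue::Int64(1);        // TIR int64 immediate wrapped as an Expr
  StringImm mode("reflect");                   // string literal argument (placeholder)
  DataTypeImm out_dtype(DataType::Float(16));  // dtype literal argument
  (void)axis;
  (void)mode;
  (void)out_dtype;
}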
*/ + Var var; + mutable Span span; + + static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); +}; + +class Binding : public ObjectRef { + protected: + Binding() = default; + + public: + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); + const BindingNode* operator->() const { return static_cast(data_.get()); } + const BindingNode* get() const { return operator->(); } + using ContainerType = BindingNode; +}; + +/*! + * \brief Runtime-match the value to the struct info. + * + * This operation does runtime check, populates the un-defined symbolic shape vars + * and vars in struct_info in first occurance, and insert equality assertions in + * other cases. + */ +class MatchCastNode : public BindingNode { + public: + /*! \brief The input value to match cast. */ + Expr value; + /*! \brief The struct info pattern to match to. */ + StructInfo struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("struct_info", &struct_info); + v->Visit("span", &span); + } + + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); + } + + static constexpr const char* _type_key = "relax.expr.MatchCast"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); +}; + +/*! + * \brief Managed reference to MatchCastNode. + * \sa MatchCastNode + */ +class MatchCast : public Binding { + public: + TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); +}; + +class VarBindingNode : public BindingNode { + public: + /*! \brief The binding value. 
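[Editor's illustration, not part of the patch] A sketch of a MatchCast binding that imposes a symbolic (m, k) float32 pattern on a value. TensorStructInfo is assumed from struct_info.h, added later in this patch.

#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::relax;

Binding BindWithShapeCheck(Expr value) {
  tir::Var m("m", DataType::Int(64));
  tir::Var k("k", DataType::Int(64));
  // Pattern: a float32 tensor of shape (m, k); m and k are defined here on first use.
  TensorStructInfo pattern(ShapeExpr({m, k}), DataType::Float(32));
  Var lhs("lv", pattern);
  return MatchCast(lhs, value, pattern);
}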
*/ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + return equal.DefEqual(var, other->var) && equal(value, other->value); + } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(var); + hash_reduce(value); + } + static constexpr const char* _type_key = "relax.expr.VarBinding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); +}; + +class VarBinding : public Binding { + public: + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); +}; + +class BindingBlockNode : public Object { + public: + mutable Span span; + Array bindings; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); + v->Visit("bindings", &bindings); + } + + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.BindingBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); +}; + +class BindingBlock : public ObjectRef { + public: + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); +}; + +class DataflowBlock; +class DataflowBlockNode : public BindingBlockNode { + public: + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); +}; + +class DataflowBlock : public BindingBlock { + public: + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); +}; + +/*! \brief A sequence of blocks followed by an expression. + * + * The order of blocks enforces scoping and ordering. 
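[Editor's illustration, not part of the patch] A sketch of assembling VarBinding objects into a DataflowBlock, using only the classes declared above; the bound values are placeholders and the variables carry no struct info annotation for brevity.

#include <tvm/relax/expr.h>

using namespace tvm;
using namespace tvm::relax;

BindingBlock MakeDataflowBlock(Expr some_value, Expr other_value) {
  DataflowVar lv0("lv0", NullOpt);  // only visible inside the dataflow block
  Var gv0("gv0", NullOpt);          // escapes the block as a regular binding
  return DataflowBlock({VarBinding(lv0, some_value), VarBinding(gv0, other_value)});
}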
+ */ +class SeqExprNode : public ExprNode { + public: + Array blocks; + Expr body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("body", &body); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { + return equal(blocks, other->blocks) && equal(body, other->body) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(blocks); + hash_reduce(body); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.SeqExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); +}; + +class SeqExpr : public Expr { + public: + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); +}; + +/*! \brief A Relax function. */ +class FunctionNode : public BaseFuncNode { + public: + /*! \brief The parameters to the function. */ + Array params; + /*! \brief The body of the function. */ + Expr body; + /*! \brief The return type of the function. */ + StructInfo ret_struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal.DefEqual(params, other->params) && equal(body, other->body) && + equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce.DefHash(params); + hash_reduce(body); + hash_reduce(ret_struct_info); + hash_reduce(attrs); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Function"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); +}; + +class Function : public BaseFunc { + public: + TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + /*! + * \brief Mimics the constructor but without body Expr. + * \note ret_struct_info is required, since it can not deduced by the body + */ + TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); +}; + +// TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and +// kPrimitive. +namespace attr { +/*! \brief Mark the function as a primitive function. */ +constexpr const char* kPrimitive = "Primitive"; +/*! + * \brief Indicate the codegen that should be used for building this function. + * When this is unset or set to "default", the default compilation pipeline will be used. + */ +constexpr const char* kCodegen = "Codegen"; +/*! 
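[Editor's illustration, not part of the patch] A sketch of wrapping a binding block into a SeqExpr and a Relax Function, and attaching a function attribute with TVM's generic WithAttr helper. The (dtype, ndim) TensorStructInfo constructor is assumed from struct_info.h later in this patch.

#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

using namespace tvm;
using namespace tvm::relax;

Function MakeIdentityFunc() {
  TensorStructInfo sinfo(DataType::Float(32), /*ndim=*/2);  // assumed (dtype, ndim) ctor
  Var x("x", sinfo);
  Var lv("lv", sinfo);
  BindingBlock block({VarBinding(lv, x)});
  SeqExpr body({block}, /*body=*/lv);
  Function func(/*params=*/{x}, body, /*ret_struct_info=*/sinfo);
  // Attributes such as attr::kPrimitive are attached with the usual WithAttr helper.
  return WithAttr(func, attr::kPrimitive, Integer(1));
}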
\brief Treat the function as a composite operator. */ +constexpr const char* kComposite = "Composite"; +/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ +constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +} // namespace attr + +/*! \brief The extern function, which can represent packed function. */ +class ExternFuncNode : public BaseFuncNode { + public: + /*! \brief The name of global symbol. */ + String global_symbol; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("global_symbol", &global_symbol); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { + return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(global_symbol); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.ExternFunc"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); +}; + +class ExternFunc : public BaseFunc { + public: + TVM_DLL ExternFunc(String global_symbol, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); +}; + +/*! + * \brief Get the shape of Expr. + * \param expr The input expr. + * \return The corresonding shape. + * + * \note This function requires expr to be normalized. + * The function will report an error if expr's StructInfo is not TensorStructInfo. + * It will try to return symbolic function when possible. If the tensor do not + * have a compile-time symbolic shape, the function will then choose to return + * Call(relax.op.shape_of, [expr]). + */ +TVM_DLL Expr GetShapeOf(const Expr& expr); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXPR_H_ diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h new file mode 100644 index 000000000000..ce209ccd460f --- /dev/null +++ b/include/tvm/relax/expr_functor.h @@ -0,0 +1,551 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. + */ +#ifndef TVM_RELAX_EXPR_FUNCTOR_H_ +#define TVM_RELAX_EXPR_FUNCTOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { + +/*! 
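[Editor's illustration, not part of the patch] A sketch of calling an opaque packed function through ExternFunc and querying a shape with GetShapeOf, both declared above. The global symbol "my_packed_func" is hypothetical, and GetShapeOf assumes its argument is normalized (carries TensorStructInfo).

#include <tvm/relax/expr.h>

using namespace tvm;
using namespace tvm::relax;

Expr CallPackedOnTensor(Expr arg, StructInfo out_sinfo) {
  ExternFunc f("my_packed_func");  // hypothetical global symbol, resolved at runtime
  // sinfo_args carries the expected output struct info of the opaque packed call.
  return Call(f, {arg}, Attrs(), /*sinfo_args=*/{out_sinfo});
}

Expr ShapeOf(Expr arg) {
  // Returns a symbolic ShapeExpr when possible, otherwise a call to relax.op.shape_of.
  return GetShapeOf(arg);
}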
+ * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); + +#define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \ + { \ + if (PY_FUNC != nullptr) \ + PY_FUNC(N); \ + else \ + DEFAULT_FUNC; \ + } + +#define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \ + { \ + if (PY_FUNC != nullptr) { \ + RET_TYPE ret = PY_FUNC(N); \ + return ret; \ + } else { \ + return DEFAULT_FUNC; \ + } \ + } + +#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) \ + self->PY_FUNC(n); \ + else \ + self->VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) { \ + Expr expr = self->PY_FUNC(n); \ + return expr; \ + } else { \ + return self->VisitExpr_(static_cast(n.get())); \ + } \ + }); + +#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \ + post_order_vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + return self->VisitExprPostOrder_(static_cast(n.get())); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + // NOTE: cross dialect calls are invoked through global var + // We do not expect inline PrimFunc to appear in relax IR. + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... 
args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAX_EXPR_FUNCTOR_DISPATCH(VarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode); + RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAX_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAX_EXPR_FUNCTOR_DISPATCH(OpNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode); + RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode); + RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around ExprFunctor. + * Recursively visit the content. + */ +class ExprVisitor : public ExprFunctor { + public: + /*! + * \brief Generic dispatcher for Expr. + * \param expr The expr to be visited. + */ + void VisitExpr(const Expr& expr) override; + // specific leaf level visitor functions + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const DataflowVarNode* op) override; + void VisitExpr_(const ShapeExprNode* op) override; + void VisitExpr_(const ExternFuncNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const SeqExprNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const PrimValueNode* op) override; + void VisitExpr_(const StringImmNode* op) override; + void VisitExpr_(const DataTypeImmNode* op) override; + + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ + virtual void VisitBinding(const Binding& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); + // second level dispatching based on binding value type. 
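[Editor's illustration, not part of the patch] A minimal ExprVisitor subclass built on the dispatch machinery above: it counts every call site reachable from an expression.

#include <tvm/relax/expr_functor.h>

using namespace tvm::relax;

class CallCounter : public ExprVisitor {
 public:
  size_t count = 0;

  void VisitExpr_(const CallNode* op) override {
    ++count;
    ExprVisitor::VisitExpr_(op);  // keep recursing into the operator and arguments
  }
};

size_t CountCalls(const Expr& expr) {
  CallCounter counter;
  counter.VisitExpr(expr);
  return counter.count;
}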
+ // these dispatching functions get called from first-level dispatch on VarBinding + virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + */ + virtual void VisitBindingBlock(const BindingBlock& block); + // specific leaf level visitor functions + virtual void VisitBindingBlock_(const BindingBlockNode* block); + virtual void VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for visiting the var definition site. + * \param var The var to be visited. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual void VisitVarDef(const Var& var); + + /*! + * \brief Visit struct_info may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if we are not interested in Expr in StructInfo, we can + * override this function by a no-op. + * + * \param struct_info Input struct info field. + */ + virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); + + // specific leaf level visitor functions + virtual void VisitVarDef_(const VarNode* var); + virtual void VisitVarDef_(const DataflowVarNode* var); + + virtual void VisitSpan(const Span& span); + virtual void VisitPrimExpr(const PrimExpr& expr); + + private: + using TSelf = ExprVisitor; + using VisitBindingVTable = + tvm::NodeFunctor; + // initialize the vtable. + static VisitBindingVTable InitVisitBindingVTable(); + /*! + * \brief Private internal struct info field visitor. + * + * Support default visiting of struct info field and recursive into + * their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprVisitor and StructInfoVisitor. 
+ */ + class DefaultStructInfoFieldVisitor : public StructInfoVisitor { + public: + explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent); + + // Override defaults in struct info visitor. + void VisitStructInfoExprField(const Expr& expr) final; + void VisitStructInfoExprField(const PrimExpr& expr) final; + void VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprVisitor* parent_; + }; + // This visitor is not visible to child classes and only + // used to supported default visiting behavior. + DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; +}; + +void PostOrderVisit(const Expr& node, std::function fvisit); + +/*! + * \brief A mutator works in unnormalized form. + * + * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., checked_type_ and shape_ + * of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in + * ANF). + */ + +class ExprMutatorBase : public ExprFunctor { + public: + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const ConstantNode* op) override; + Expr VisitExpr_(const TupleNode* op) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const ShapeExprNode* op) override; + Expr VisitExpr_(const ExternFuncNode* op) override; + Expr VisitExpr_(const GlobalVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const CallNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + Expr VisitExpr_(const OpNode* op) override; + Expr VisitExpr_(const TupleGetItemNode* op) override; + Expr VisitExpr_(const PrimValueNode* op) override; + Expr VisitExpr_(const StringImmNode* op) override; + Expr VisitExpr_(const DataTypeImmNode* op) override; + + /*! + * \brief Mutate BindingBlock. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); + + /*! + * \brief Used to visit the PrimExpr inside of expressions. + * + * Can be overloaded to transform the shape expressions. + */ + virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); + + /*! + * \brief Visit struct_info that may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if in Expr in StructInfo won't change, we can + * override this function by an identity function. + * + * \param struct_info Input struct info field. + * \return The updated struct info. + */ + virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); + + protected: + /*! + * \brief Check whether VisitExprDepStructInfoField change struct_info. + * \return Whether struct info changed. + * \note This function is used by mutator implementations to check if + * previous Expr update will trigger a change in struct_info. + * If change is detected, the implementation can generate a fresh + * node without struct_info, and trigger normalizer to re-derive. 
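[Editor's illustration, not part of the patch] A sketch of the PostOrderVisit helper declared above, collecting every variable use (including DataflowVar uses, since DataflowVarNode derives from VarNode).

#include <tvm/relax/expr_functor.h>

#include <vector>

using namespace tvm;
using namespace tvm::relax;

std::vector<Var> CollectVarUses(const Expr& expr) {
  std::vector<Var> result;
  PostOrderVisit(expr, [&result](const Expr& node) {
    if (const auto* var = node.as<VarNode>()) {
      result.push_back(GetRef<Var>(var));
    }
  });
  return result;
}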
+ */ + bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) { + if (const StructInfoNode* sinfo = struct_info.as()) { + return this->VisitExprDepStructInfoField(GetRef(sinfo)).same_as(struct_info); + } else { + return true; + } + } + + private: + /*! + * \brief Private internal struct info field visitor to support + * Default visiting of struct info field and recursive into their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprMutator and StructInfoMutator. + */ + class DefaultStructInfoFieldMutator : public StructInfoMutator { + public: + explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent); + + // Override defaults in struct info visitor. + Expr VisitStructInfoExprField(const Expr& expr) final; + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprMutatorBase* parent_; + }; + // This visitor is not visible to child classes and only + // used to supported default visiting behavior. + DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; +}; + +/*! + * \brief A mutator works in normal form. + * + * ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ +class ExprMutator : public ExprMutatorBase { + public: + using ExprMutatorBase::VisitExpr_; + + ExprMutator(Optional mod = NullOpt) { builder_ = BlockBuilder::Create(mod); } + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ + virtual void VisitBinding(const Binding& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); + // second level dispatching based on binding value type. 
+ // these dispatching functions get called from first-level dispatch on VarBinding + virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*) + // specific leaf level visitor functions + virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block); + virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for rewriting the var definition site. + * \param var The var to be visited. + * \return The var after post-order rewritten. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual Var VisitVarDef(const Var& var); + // specific leaf level visitor functions + virtual Var VisitVarDef_(const VarNode* var); + virtual Var VisitVarDef_(const DataflowVarNode* var); + + protected: + /*! + * \brief Try to remit binding and bind it to a new_value + * + * This function is called after VisitExpr(binding->value) in + * VisitBinding_(const VarBinding*). + * It will try to reuse the current binding when the new value's shape/type + * matches the original binding and no changes in var is needed. + * + * Otherwise, a new binding will be emitted to replace the var specified in + * the current binding. + */ + void ReEmitBinding(const VarBindingNode* binding, Expr new_value); + + /*! + * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * + * \param body_expr The body to be visited. + * \param params Optional parameters that are visible within the scope. + * \return The expr after visiting. + * + * \note The body_expr must be an SeqExpr in the normal form. + */ + Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + Optional LookupBinding(const Var& var); + + /*! 
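[Editor's illustration, not part of the patch] A sketch of an ExprMutator subclass that rewrites calls to one operator into another; the two operator names are assumptions standing in for ops registered elsewhere in this patch. A pass would typically apply such a mutator to each Relax function of an IRModule.

#include <tvm/ir/op.h>
#include <tvm/relax/expr_functor.h>

using namespace tvm;
using namespace tvm::relax;

class SwapOpMutator : public ExprMutator {
 public:
  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const CallNode* call) override {
    // Let the base mutator rewrite the arguments first.
    Call new_call = Downcast<Call>(ExprMutator::VisitExpr_(call));
    static const Op& from_op = Op::Get("relax.nn.relu");  // assumed op names
    static const Op& to_op = Op::Get("relax.nn.gelu");
    if (new_call->op.same_as(from_op)) {
      return Call(to_op, new_call->args, new_call->attrs, new_call->sinfo_args);
    }
    return new_call;
  }
};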
+ * \brief Post-order rewrite a node and normalize. + * \tparam T The node type to be rewritten. + * \param op The node to be rewritten. + * \return The node after post rewritten. + */ + template + Expr VisitExprPostOrder_(const T* op) { + return builder_->Normalize(ExprMutator::VisitExpr_(op)); + } + + /*! + * \brief Create a new var with specified struct_info if the original var's shape or type does + * not match with the specified ones. + * \param var The var to be updated. + * \param struct_info The struct info to be updated. + * \return The var filled with struct_info + */ + Var WithStructInfo(Var var, StructInfo struct_info); + + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; + + /*! \brief Remap a var to a new var in use-site. */ + std::unordered_map var_remap_; + + private: + using TSelf = ExprMutator; + using VisitBindingVTable = + tvm::NodeFunctor; + // initialize the vtable. + static VisitBindingVTable InitVisitBindingVTable(); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h new file mode 100644 index 000000000000..0564c2668797 --- /dev/null +++ b/include/tvm/relax/nested_msg.h @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/nested_msg.h + * \brief Helper container to store nested message for robust tuple-aware analysis. + * + * Please see NestedMsg for description of usage. + * + * \sa NestedMsg + */ +#ifndef TVM_RELAX_NESTED_MSG_H_ +#define TVM_RELAX_NESTED_MSG_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Container that stores possibly nested message with leaf message type T. + * + * NestedMsg is a helper structure to store intermediate + * message state in pass analysis so we can robustly handle message + * passing with the presence of nested tuple types. + * + * Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]]. + * Each nested message corresponds to the same nesting structure as + * the nested tuple types when we encounter them in analysis. + * + * Relax support nested tuple structures in the IR. Nested tuple structure + * is important to support advanced groupings in cases such as gradient calculation + * and other scenarios. + * + * The possible presence of nested tuple does mean that we need to + * to robustly handle analysis that contains nested tuple structures + * in a dataflow graph. 
+ * + * \code + * + * v1 = relu(v0) + * v2 = exp(v0) + * t = ((v0, v1), (v2,), v0) + * t1 = t[0] + * v3 = concat(t1) + * v4 = t[2] + * v5 = add(v4, v3) + * + * \endcode + * + * Consider the above code sequence that contains a mixture of tuple + * nesting and normal operations. A common message-passing-based analysis + * will track messages attached to each intermediate variable. + * + * Because the intermediate value can contain nested-tuples, we need to have + * abilities to nest messages according to tuple structure and propagate them + * along the way. In python, this simply corresponds to using a tuple to hold + * nested messages. This class provides a helper wrapper in C++ to present such + * possibly nested message for a given leaf message. + * + * This design pattern is necessary to handle tuple values regardless of + * the normal form design of the IR to enable different messages for each + * tuple component without enforcing all tuple elements to have the same message. + * + * Please consider the following patterns in our pass: + * + * On a forward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan forward + * - input_msg = [MapToNestedMsg(x, lookup_map) for x in call->args] + * - output_msg = ForwardProp[call->op](input_msg, call) + * - map[binding->var] = output_msg + * - Use MapToNestedMsg to remap the remaining body. + * + * On a backward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan backward + * - output_msg = lookup map(binding->var) + * - handle case when output_msg is null + * - input_msg = BackProp[call->op](out_msg, call) + * - for arg, msg in zip(call->args, input_msg), + * DecomposeNestedMessage(arg, msg, lambda node, m: update_map(node, m)) + * - update_map(node, m) => CombineNestedMessage(map[node], m) + * + * Here leafnode is a node that you would like to propagate messages to + * such as constant, var and should not include tuple. + * + * We also recommend writing unit-test cases that involve nested tuple composition + * and decomposition. + * + * \sa MapToNestedMsg, DecomposeNestedMsg, CombineNestedMsg, ForEachLeaf, Equal + * + * \note If you want to write robust message passing-based analysis for + * programs that can contain nested tuples, you likely need to + * use this class or logic of a similar kind. + */ +template +class NestedMsg : public ObjectRef { + public: + // default constructors. + NestedMsg() = default; + NestedMsg(const NestedMsg&) = default; + NestedMsg(NestedMsg&&) = default; + NestedMsg& operator=(const NestedMsg&) = default; + NestedMsg& operator=(NestedMsg&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already satisfies the constraint + * \param ptr + */ + explicit NestedMsg(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + NestedMsg(runtime::NullOptType) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit NestedMsg(std::nullptr_t) {} + NestedMsg& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. 
+ NestedMsg(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // Array> handling + NestedMsg(Array, void> other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(Array, void> other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + + // initializer list handling + NestedMsg(std::initializer_list> other) // NOLINT(*) + : NestedMsg(Array, void>(other)) {} + NestedMsg& operator=(std::initializer_list> other) { + return operator=(Array, void>(other)); + } + + // delete the int constructor + // since NestedMsg(0) is ambiguous + // 0 can be implicitly casted to nullptr_t + explicit NestedMsg(int val) = delete; + NestedMsg& operator=(int val) = delete; + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + /*! \return Whether the nested message is not-null leaf value */ + bool IsLeaf() const { return data_ != nullptr && data_->IsInstance(); } + + /*! \return Whether the nested message is null */ + bool IsNull() const { return data_ == nullptr; } + + /*! \return Whether the nested message is nested */ + bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + + /*! + * \return The underlying leaf value. + * \note This function checks if the msg is leaf. + */ + T LeafValue() const { + ICHECK(IsLeaf()); + return T(data_); + } + + /*! + * \return a corresponding nested array. + * \note This checks if the underlying data type is array. + */ + Array, void> NestedArray() const { + ICHECK(IsNested()); + return Array, void>(data_); + } + + using ContainerType = Object; + using LeafContainerType = typename T::ContainerType; + + static_assert(std::is_base_of::value, "NestedMsg is only defined for ObjectRef."); + + static constexpr bool _type_is_nullable = true; +}; + +/*! + * \brief Apply fvisit for each leaf elements in the nested message. + * \param fvisit The visit callback. + * \param msg The input nested message. + * \tparam T the content type of nested msg + * \tparam FType the visitor type with signature void fvisit(T) + */ +template +void ForEachLeaf(const NestedMsg& msg, FType fvisit) { + if (msg == nullptr) return; + if (msg.IsLeaf()) { + fvisit(msg.LeafValue()); + } else { + for (NestedMsg x : msg.NestedArray()) { + ForEachLeaf(x, fvisit); + } + } +} + +/*! + * \brief Recursively compare two nested messages. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fequal The equal functor with signature bool fequal(T, T) + * \tparam T the content type of nested msg + * \tparam FType the equal comparator type + */ +template +bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { + if (lhs.IsNull()) return rhs.IsNull(); + if (rhs.IsNull()) return lhs.IsNull(); + if (lhs.IsLeaf()) { + return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue()); + } else { + if (!rhs.IsNested()) return false; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + if (arr_lhs.size() != arr_rhs.size()) return false; + for (size_t i = 0; i < arr_lhs.size(); ++i) { + if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false; + } + return true; + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive tuples and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. 
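[Editor's illustration, not part of the patch] A sketch of constructing and traversing a NestedMsg with the ForEachLeaf and Equal helpers defined below, using Integer as the leaf message type.

#include <tvm/ir/expr.h>
#include <tvm/relax/nested_msg.h>

using namespace tvm;
using namespace tvm::relax;

void NestedMsgDemo() {
  // Leaf messages are Integer objects; the nesting mirrors the tuple ((1, 2), 3).
  NestedMsg<Integer> msg = {{Integer(1), Integer(2)}, Integer(3)};
  int64_t sum = 0;
  ForEachLeaf(msg, [&sum](Integer v) { sum += v->value; });  // sum becomes 6
  NestedMsg<Integer> same_shape = {{Integer(1), Integer(2)}, Integer(3)};
  bool eq = Equal(msg, same_shape,
                  [](Integer a, Integer b) { return a->value == b->value; });
  ICHECK(eq && sum == 6);
}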
+ * + * The nesting structure will corresponds to the tuple structure. + * + * \param expr The input expression. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { + if (auto* tuple = expr.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (Expr x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map structinfo with possible nested-sinfo to nested message. + * + * This function will unpack recursive sinfo and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. + * + * The nesting structure will corresponds to the tuple structure. + * + * \param sinfo The input struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(StructInfo)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (StructInfo x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(sinfo); + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive expr by its struct info and + * run fmapleaf for each leaf, then recursively combines the results + * together into a NestedMsg. + * + * The nesting structure will corresponds to the struct info of expr. + * + * \param expr The input expression which should have struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmapleaf(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { + auto sinfo = GetStructInfo(expr); + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = expr.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(expr, i); + UpdateStructInfo(field, tuple->fields[i]); + } + res.push_back(MapToNestedMsgBySInfo(field, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map nested message back to the expr. + * + * This function will decompose the nested message and + * run fmapleaf for each leaf message and get the leaf expr, + * then recursively combines the results as tuple expr. + * + * \param msg The input nested message. + * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional)`. + * \tparam T the content type of nested msg. + * \tparam FType The mapping function type. 
+ */ +template +Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return fmapleaf(NullOpt); + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array subexpr; + subexpr.reserve(arr.size()); + for (size_t i = 0; i < arr.size(); ++i) { + subexpr.push_back(NestedMsgToExpr(arr[i], fmapleaf)); + } + Optional simplified_tuple; + bool simplified_flag = false; + if (subexpr.size() >= 1) { + simplified_flag = true; + for (size_t i = 0; i < subexpr.size() && simplified_flag; ++i) { + auto* node = subexpr[i].as(); + if (node == nullptr || node->index != static_cast(i)) { + simplified_flag = false; + } else { + if (simplified_tuple.defined()) { + simplified_flag &= (simplified_tuple == node->tuple); + } else { + simplified_tuple = node->tuple; + ICHECK(simplified_tuple.defined()); + } + } + } + } + return simplified_flag ? simplified_tuple.value() : Tuple(subexpr); + } +} + +/*! + * \brief Recursively combine two nested message into one. + * + * This function requires the two messages to be compatible with each other. + * The combination rule is as follows: + * - combine(null, msg) => msg + * - combine(leaf1, leaf2) => fcombine(leaf1, leaf2) + * - combine(array1, array2) => [combine(x, y) for x, y in zip(array1, array2)] + * - This function will throw an error if array have different size + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fcombine with signature T fcombine(T lhs, T rhs) + * \tparam T the content type of nested msg + * \tparam FType combine function type. + */ +template +NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine) { + if (lhs.IsNull()) return rhs; + if (rhs.IsNull()) return lhs; + + if (lhs.IsLeaf()) { + ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested"; + return NestedMsg(fcombine(lhs.LeafValue(), rhs.LeafValue())); + } else { + ICHECK(lhs.IsNested()); + ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) + << "Cannot combine two nested array with different sizes"; + Array> res; + res.reserve(arr_lhs.size()); + for (size_t i = 0; i < arr_lhs.size(); ++i) { + res.push_back(CombineNestedMsg(arr_lhs[i], arr_rhs[i], fcombine)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively map a nested message to another one, with leaf mapped by the input fmapleaf. + * \param msg The nested message to be mapped. + * \param fmapleaf The leaf map function, with signature NestedMsg fmapleaf(T msg) + * \tparam T The content type of nested message. + * \tparam FType The leaf map function type. + * \return The new nested message. + */ +template +NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return msg; + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array> res; + res.reserve(arr.size()); + for (int i = 0; i < static_cast(arr.size()); ++i) { + res.push_back(MapNestedMsg(arr[i], fmapleaf)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively decompose the tuple structure in expr and msg along with it. + * + * This function will call fvisitleaf for each leaf expression in expr. + * This function will throw an error if the nesting structure in msg does not + * match the tuple nesting structure in expr. 
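[Editor's illustration, not part of the patch] A sketch of the message-passing round trip described above: decompose a possibly tuple-valued expression into per-leaf messages via MapToNestedMsgBySInfo, rebuild it with NestedMsgToExpr, and combine two messages leaf-wise with CombineNestedMsg. The input expression is assumed to be normalized (it must carry struct info).

#include <tvm/relax/nested_msg.h>

using namespace tvm;
using namespace tvm::relax;

Expr RoundTripThroughMsg(Expr expr) {
  // Tag every leaf (per the struct info of `expr`) with the leaf expression itself.
  NestedMsg<Expr> msg = MapToNestedMsgBySInfo<Expr>(
      expr, [](Expr leaf) { return NestedMsg<Expr>(leaf); });
  // Fold the nested message back into an expression; trivial tuple-of-getitem
  // patterns are simplified away by NestedMsgToExpr.
  return NestedMsgToExpr<Expr>(msg, [](Optional<Expr> leaf) {
    ICHECK(leaf.defined());
    return leaf.value();
  });
}

// Leaf-wise combination of two compatible messages, preferring the left one.
NestedMsg<Expr> PreferLeft(NestedMsg<Expr> lhs, NestedMsg<Expr> rhs) {
  return CombineNestedMsg(lhs, rhs, [](Expr a, Expr /*b*/) { return a; });
}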
+ *
+ * \param expr The input expression to be decomposed.
+ * \param msg The input nested message.
+ * \param fvisitleaf with signature fvisitleaf(Expr expr, NestedMsg<T> msg)
+ * \tparam T the content type of nested msg
+ * \tparam FType The visit function type.
+ */
+template <typename T, typename FType>
+void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
+  if (auto* tuple = expr.as<TupleNode>()) {
+    ICHECK(msg.IsNested()) << "Expected nested to match tuple";
+    Array<NestedMsg<T>> arr = msg.NestedArray();
+    ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size";
+    for (size_t i = 0; i < arr.size(); ++i) {
+      DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf);
+    }
+  } else {
+    fvisitleaf(expr, msg);
+  }
+}
+
+/*!
+ * \brief Recursively transform the tuple structure in expr and msgs along with it.
+ *
+ * This function will call ftransleaf for each leaf expression in expr.
+ * This function will throw an error if the nesting structure in msgs does not
+ * match the tuple nesting structure in expr.
+ *
+ * \param expr The input expression to be transformed.
+ * \param msgs The input messages to guide the transformation.
+ * \param ftransleaf with signature ftransleaf(Expr, Array<NestedMsg<T>>)->Expr
+ * \tparam T the content type of nested msg
+ * \tparam N the number of messages
+ * \tparam FType The visit function type.
+ */
+template <typename T, std::size_t N, typename FType>
+Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
+  StructInfo sinfo = GetStructInfo(expr);
+  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+    std::array<Array<NestedMsg<T>>, N> msg_arrays;
+    for (size_t i = 0; i < N; ++i) {
+      ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
+      msg_arrays[i] = msgs[i].NestedArray();
+    }
+    bool same = true;
+    Array<Expr> fields;
+    fields.reserve(tuple->fields.size());
+    for (size_t i = 0; i < tuple->fields.size(); ++i) {
+      Expr field;
+      if (const auto* expr_tuple = expr.as<TupleNode>()) {
+        field = expr_tuple->fields[i];
+      } else {
+        field = TupleGetItem(expr, i);
+        UpdateStructInfo(field, tuple->fields[i]);
+      }
+      std::array<NestedMsg<T>, N> sub_msgs;
+      for (size_t j = 0; j < N; ++j) {
+        sub_msgs[j] = msg_arrays[j][i];
+      }
+      fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
+      same &= (fields.back().same_as(field));
+    }
+    return same ? expr : Tuple(fields);
+  } else {
+    for (const auto& msg : msgs) {
+      ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
+    }
+    return ftransleaf(expr, msgs);
+  }
+}
+
+/*!
+ * \brief Recursively transform the tuple structure in sinfo and msgs along with it.
+ *
+ * This function will call ftransleaf for each leaf sinfo in sinfo.
+ * This function will throw an error if the nesting structure in msgs does not
+ * match the tuple nesting structure in sinfo.
+ *
+ * \param sinfo The input sinfo to be transformed.
+ * \param msgs The input messages to guide the transformation.
+ * \param ftransleaf with signature ftransleaf(StructInfo, Array<NestedMsg<T>>)->StructInfo
+ * \tparam T the content type of nested msg
+ * \tparam N the number of messages
+ * \tparam FType The visit function type.
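+ *
+ * A minimal usage sketch with a single message (illustrative only; `sinfo`
+ * and `msg` are assumed to be given, with T = StructInfo):
+ *
+ * \code{.cpp}
+ * StructInfo updated = TransformTupleLeaf<StructInfo>(
+ *     sinfo, std::array<NestedMsg<StructInfo>, 1>{msg},
+ *     [](StructInfo leaf, std::array<NestedMsg<StructInfo>, 1> msgs) {
+ *       return msgs[0].IsLeaf() ? msgs[0].LeafValue() : leaf;
+ *     });
+ * \endcode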
+ */ +template +StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs, + FType ftransleaf) { + if (const auto* tuple = sinfo.as()) { + std::array>, N> msg_arrays; + for (size_t i = 0; i < N; ++i) { + ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; + msg_arrays[i] = msgs[i].NestedArray(); + } + bool same = true; + Array fields; + fields.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + StructInfo field = tuple->fields[i]; + std::array, N> sub_msgs; + for (size_t j = 0; j < N; ++j) { + sub_msgs[j] = msg_arrays[j][i]; + } + fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf)); + same &= (fields.back().same_as(field)); + } + return same ? sinfo : TupleStructInfo(fields); + } else { + for (const auto& msg : msgs) { + ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; + } + return ftransleaf(sinfo, msgs); + } +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_NESTED_MSG_H_ diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h new file mode 100644 index 000000000000..413d3e0499d0 --- /dev/null +++ b/include/tvm/relax/op_attr_types.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef TVM_RELAX_OP_ATTR_TYPES_H_ +#define TVM_RELAX_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Infer output struct info given the call + * + * \param call The call expression to be derived. + * \param ctx The builder context. + */ +using FInferStructInfo = + runtime::TypedPackedFunc; + +/*! + * \brief Packed function implementation for operators. The relax operator will be lowered to + * this packed function call during codegen. + */ +using FCallPacked = String; + +/*! + * \brief The function type of a legalization function, which takes a + * BlockBuilder and the Call to be legalized, and outputs the legalization + * result Expr. + * \param bb The BlockBuilder context. + * \param call The call to be legalized. + */ +using FLegalize = runtime::TypedPackedFunc; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h new file mode 100644 index 000000000000..0c1973bceac9 --- /dev/null +++ b/include/tvm/relax/struct_info.h @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_STRUCT_INFO_H_ +#define TVM_RELAX_STRUCT_INFO_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Opaque object. + */ +class ObjectStructInfoNode : public StructInfoNode { + public: + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ObjectStructInfoNode. + * \sa ObjectStructInfoNode + */ +class ObjectStructInfo : public StructInfo { + public: + TVM_DLL ObjectStructInfo(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimStructInfoNode : public StructInfoNode { + public: + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + + static constexpr const char* _type_key = "relax.PrimStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to PrimStructInfoNode. + * \sa PrimStructInfoNode + */ +class PrimStructInfo : public StructInfo { + public: + TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); +}; + +/*! + * \brief StructInfo of shape value. + */ +class ShapeStructInfoNode : public StructInfoNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + Optional> values; + /*! + * \brief The number of dimension of the shape, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.ShapeStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ShapeStructInfoNode. 
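+ *
+ * A minimal construction sketch (illustrative only):
+ *
+ * \code{.cpp}
+ * // A static 2-d shape [3, 4], and a shape with unknown values but known ndim.
+ * ShapeStructInfo static_shape(
+ *     Array<PrimExpr>{IntImm(DataType::Int(64), 3), IntImm(DataType::Int(64), 4)});
+ * ShapeStructInfo rank_only(2);  // ndim = 2, values unknown
+ * \endcode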
+ * \sa ShapeStructInfoNode + */ +class ShapeStructInfo : public StructInfo { + public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + /*! + * \brief Construction with known unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownNDim + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); +}; + +/*! + * \brief StructInfo of Tensor. + */ +class TensorStructInfoNode : public StructInfoNode { + public: + /*! + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + */ + Optional shape; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + /*! \return Whether the struct info contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + /*! \return Shape if it is known. */ + Optional> GetShape() const { + if (!shape.defined()) return {}; + ShapeStructInfo shape_sinfo = Downcast(this->shape.value()->struct_info_); + return shape_sinfo->values; + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { + return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(shape); + hash_reduce(dtype); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.TensorStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TensorStructInfoNode. + * \sa TensorStructInfoNode + */ +class TensorStructInfo : public StructInfo { + public: + /*! + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param span The span of the AST. + * + * \note shape must already be normalized. + */ + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span()); + + /*! + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param span The span of the AST. + */ + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); +}; + +/*! + * \brief StructInfo of Tuple. + */ +class TupleStructInfoNode : public StructInfoNode { + public: + /*! \brief The struct info of tuple fields. 
*/ + Array fields; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.TupleStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TupleStructInfoNode. + * \sa TupleStructInfoNode + */ +class TupleStructInfo : public StructInfo { + public: + /*! + * \brief Constructor + * \param fields Struct info of tuple fields. + * \param span The span of the AST. + */ + TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); +}; + +/*! + * \brief custom-defined StructInfo derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived struct info of the call. + */ +using StructInfoDeriveFunc = TypedEnvFunc; + +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to + * do best-effort structure information deduction. + */ +class FuncStructInfoNode : public StructInfoNode { + public: + /*! + * \brief The parameter struct info of the function. + * \note When params is NullOpt means the function can take arbitrary number of arguments. + * We define such functions as Opaque function. + */ + Optional> params; + /*! + * \brief The struct info of the function's return value. + */ + StructInfo ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be NullOpt, + * ret should be ObjectStructInfo() + */ + Optional derive_func; + + /*! + * \return Whether the func struct info is opaque. + * \note We define a function as opaque we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("ret", &ret); + v->Visit("derive_func", &derive_func); + v->Visit("span", &span); + } + + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { + return equal.DefEqual(params, other->params) && equal(ret, other->ret) && + equal(derive_func, other->derive_func); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(params); + hash_reduce(ret); + hash_reduce(derive_func); + } + + static constexpr const char* _type_key = "relax.FuncStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to FuncStructInfoNode. + * \sa FuncStructInfoNode + */ +class FuncStructInfo : public StructInfo { + public: + /*! + * \brief Constructor from parameter struct info and return value struct info. + * \param params The struct info of function parameters. + * \param ret The return value struct info. + * \param span The span of the AST. + * + * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncStructInfo(Array params, StructInfo ret, Span span = Span()); + + /*! + * \brief Constructing an opaque function struct info using derive_func. + * + * \param derive_func Derivation function. 
+ * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + + /*! + * \brief Construct an opaque function using from return struct info. + * + * \param ret The struct info of the return value. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); +}; + +/*! + * \brief Match and check if expr have StructInfo T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying structure info type + */ +template +inline Optional MatchStructInfo(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->struct_info_.as()) { + return GetRef(ptr); + } else { + return NullOpt; + } +} + +/*! + * \brief Get the structure info of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match + * \tparam T the underlying structure info type + */ +template +inline const T* GetStructInfoAs(const Expr& expr) { + ICHECK(expr->struct_info_.defined()) + << "The struct_info is not populated, check if you have normalized the expr"; + return expr->struct_info_.as(); +} + +/*! + * \brief Get the underlying structure info of expr. + * + * \param expr The input expression. + * \return underlying struct info. + */ +inline StructInfo GetStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + return GetRef(ptr); +} + +/*! + * \brief Whether the expr has void struct info. + * + * \param expr The input expression. + * \return Whether the expr has void struct info. + */ +inline bool HasVoidStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + return ptr != nullptr && ptr->fields.size() == 0; +} + +/*! + * \brief Update the struct info of an Expr. + * \param expr The Expr whose struct info to be updated. + * \param struct_info The struct_info assigned. + * \note We ensure idempotence, that is we can only update the struct_info of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h new file mode 100644 index 000000000000..382b4ab2c936 --- /dev/null +++ b/include/tvm/relax/struct_info_functor.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/struct_info_functor.h + * \brief Functors and visitors for struct info. + */ +#ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ +#define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +template +class StructInfoFunctor; + +// functions to be overriden. +#define STRUCT_INFO_FUNCTOR_DEFAULT \ + { return VisitStructInfoDefault_(op, std::forward(args)...); } + +#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStructInfo_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class StructInfoFunctor { + private: + using TSelf = StructInfoFunctor; + using FStructInfo = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~StructInfoFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const StructInfo& n, Args... args) { + return VisitStructInfo(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitStructInfo(const StructInfo& n, Args... args) { + ICHECK(n.defined()); + static FStructInfo vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitStructInfo_(const ObjectStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const PrimStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const ShapeStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TensorStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TupleStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const FuncStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfoDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; // unreachable, written to stop compiler warning + } + + private: + // initialize the vtable. + static FStructInfo InitVTable() { + FStructInfo vtable; + // Set dispatch + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ObjectStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(PrimStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ShapeStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TensorStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode); + return vtable; + } +}; + +#undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH + +/*! + * \brief A struct info visitor. 
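+ *
+ * A minimal subclass sketch (illustrative only): record every shape value
+ * that appears as a bare symbolic variable in the visited struct info.
+ *
+ * \code{.cpp}
+ * class ShapeVarRecorder : public StructInfoVisitor {
+ *  public:
+ *   Array<tir::Var> vars;
+ *
+ *  protected:
+ *   void VisitStructInfoExprField(const PrimExpr& e) final {
+ *     if (const auto* v = e.as<tir::VarNode>()) vars.push_back(GetRef<tir::Var>(v));
+ *   }
+ * };
+ * \endcode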
+ */ +class TVM_DLL StructInfoVisitor : public StructInfoFunctor { + public: + void VisitStructInfo_(const ObjectStructInfoNode* op) override; + void VisitStructInfo_(const PrimStructInfoNode* op) override; + void VisitStructInfo_(const ShapeStructInfoNode* op) override; + void VisitStructInfo_(const TensorStructInfoNode* op) override; + void VisitStructInfo_(const TupleStructInfoNode* op) override; + void VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual void VisitStructInfoExprField(const Expr& expr) {} + virtual void VisitStructInfoExprField(const PrimExpr& expr) {} +}; + +/*! + * \brief StructInfoMutator that mutates struct info. + */ +class TVM_DLL StructInfoMutator : public StructInfoFunctor { + public: + StructInfo VisitStructInfo_(const ObjectStructInfoNode* op) override; + StructInfo VisitStructInfo_(const PrimStructInfoNode* op) override; + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TupleStructInfoNode* op) override; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; } + virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h new file mode 100644 index 000000000000..02634dcbbf71 --- /dev/null +++ b/include/tvm/relax/tir_pattern.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_pattern.h + * \brief Data Structure of TIR Pattern used for matching. + */ + +#ifndef TVM_RELAX_TIR_PATTERN_H_ +#define TVM_RELAX_TIR_PATTERN_H_ + +#include + +namespace tvm { +namespace relax { + +using TIRPattern = tir::PrimFunc; + +/* + * \brief The match result of a TIR pattern. + */ +class MatchResultNode : public Object { + public: + /*! The matched tir pattern*/ + TIRPattern pattern; + /*! \brief The evaluated values of symbolic vars. */ + Array symbol_values; + /*! \brief The matched buffers of input and output. */ + Array matched_buffers; + void VisitAttrs(AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("symbol_values", &symbol_values); + v->Visit("matched_buffers", &matched_buffers); + } + static constexpr const char* _type_key = "relax.MatchResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object); +}; + +/*! + * \brief Managed reference to MatchResultNode. 
+ */ +class MatchResult : public ObjectRef { + public: + /*! + * \brief Constructor + * \param pattern The matched tir pattern. + * \param symbol_values The evaluated values of symbolic vars. + * \param matched_buffers The matched buffers of input and output. + */ + TVM_DLL explicit MatchResult(TIRPattern pattern, Array symbol_values, + Array matched_buffers); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode) +}; + +using FCodegen = runtime::TypedPackedFunc(Array match_results)>; +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h new file mode 100644 index 000000000000..f6acf80bebf1 --- /dev/null +++ b/include/tvm/relax/transform.h @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform.h + * \brief Relax specific transformation passes. + */ +#ifndef TVM_RELAX_TRANSFORM_H_ +#define TVM_RELAX_TRANSFORM_H_ + +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; + +/*! + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * \param traceable Boolean variable whether the dataflowblock pass is traceable. + * + * \return The created function pass. + */ +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required, bool traceable = false); + +/*! + * \brief Create a dataflowblock pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the dataflowblock pass. + * \param name The name of the dataflowblock pass. + * \param required The list of the passes that the dataflowblock pass is dependent on. + * \param traceable Boolean variable whether the dataflowblock pass is traceable. + * + * \return The created dataflowblock pass. + */ +TVM_DLL Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required, bool traceable = false); + +/*! + * \brief Perform lambda lifting to lift functions from nested into global. + * + * \return The Pass. + */ +TVM_DLL Pass LambdaLift(); + +/*! + * \brief Transform all dataflow structure to non-dataflow version. 
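+ *
+ * A minimal application sketch (illustrative only; like any tvm::transform
+ * pass, the returned Pass can be invoked on an IRModule directly;
+ * `LoadMyModule` is a hypothetical helper):
+ *
+ * \code{.cpp}
+ * IRModule mod = LoadMyModule();
+ * mod = relax::transform::ToNonDataflow()(mod);
+ * \endcode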
+ * + * \return The Pass. + */ +TVM_DLL Pass ToNonDataflow(); + +/*! + * \brief Perform explicit tensor allocation for call_tir and call_dps_packed. + * + * \return The Pass. + */ +TVM_DLL Pass CallTIRRewrite(); + +/*! + * \brief Convert all reshape-like call_tir whose corresponding binding + * vars are DataflowVars to relax.reshape operator calls. The relax.reshape + * calls will be lowered an external builtin function call in a subsequent + * pass, where the external builtin function does a CreateView operation + * at runtime, instead of doing real data copy. + * Here "reshape-like" includes reshape, expand_dims, flatten, etc. + * + * \return The Pass. + * \note The pass is applied at the first stage of Relax VM build, before + * rewriting call_tir, as this pass requires dataflow information. + */ +TVM_DLL Pass RewriteDataflowReshape(); + +/*! + * \brief The static memory planning pass on BindingBlock level. + * The pass will reuse allocated memory to its best effort, in order to + * reduce the total amount of allocated memory size. + * + * \return The pass. + */ +TVM_DLL Pass StaticPlanBlockMemory(); + +/*! + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + * + * \return The Pass. + */ +TVM_DLL Pass AttachGlobalSymbol(); + +/*! + * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the + * checked_type_ and shape_ of expressions. + * + * \return The Pass. + */ +TVM_DLL Pass Normalize(); + +/*! + * \brief Simplify a Relax module by folding var bindings and match shape nodes. + * May include other forms of expression simplification in the future. + * Best used alongside constant folding and eliminating unused bindings. + * + * \return The Pass. + */ +TVM_DLL Pass CanonicalizeBindings(); + +/*! + * Eliminate common subexpressions within dataflow blocks. + * \return The pass that eliminates common subexpressions. + * + * \note For functions local to dataflow blocks, this pass performs + * CSE *within* those functions. + */ +TVM_DLL Pass EliminateCommonSubexpr(); + +/*! + * \brief Bind params of function of the module to constant tensors. + * + * \param func_name The name of the function to bind parameters. + * \param params The parameters to bind. + * + * \return The Pass. + */ +TVM_DLL Pass BindParams(String func_name, Map params); + +/*! + * \brief Fold constant expressions. + * + * \return The Pass. + */ +TVM_DLL Pass FoldConstant(); + +/*! + * \brief Legalize high-level operator calls in Relax functions to call_tir + * with corresponding low-level TIR PrimFuncs. + * + * For each high-level operator, we register the way of legalizing it as a + * function, which takes a context BlockBuilder and the Call being legalized + * as input, and returns the legalized call. Here the input BlockBuilder is + * mainly used for adding the PrimFunc created by call_te into the context + * IRModule. + * + * The legalization function for each operator is registered as an attribute (with + * attribute key `FLegalize`) of the operator. + * + * For customizability, the user can pass their own legalization by an optional customized map, + * with the key to be the operator name and value to be the legalization function. + * The default legalization function will be overridden by the customized one. + * + * \param cmap The customized operator legalization function map. The customized function + * will override the default one. + * \return The Pass. + */ +TVM_DLL Pass LegalizeOps(Optional> cmap); + +/*! 
+ * \brief Lift transformation of the parameters of a function. + * + * When some inputs of the function is marked as 'parameters' (the model weights), this pass + * identifies the transformation of the parameters and lifts them to a separate function called + * `transform_params`. `transform_params` takes a tuple of the original parameters as input and + * returns a tuple of the transformed parameters. The original function will be rewritten to accept + * a tuple of transformed parameters as input. + * + * Users are expected to invoke the `transform_params` function in runtime and pass the transformed + * parameters to the original function as input. + * + * \return The Pass. + */ +TVM_DLL Pass LiftTransformParams(); + +/*! + * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps. + * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be + * "opaque" of we can't detect it. Users can manually annotate the attr `op_pattern` + * to prim_func. + * \return The Pass. + */ +TVM_DLL Pass AnnotateTIROpPattern(); + +/*! + * \brief This pass groups bindings in a dataflow block of Relax functions and generates a new + * grouped Relax function for each group, according to the fusion algorithm described in the pass + * implementation. By grouping bindings into new Relax functions, we substitute the bindings in the + * function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + * \param fuse_opt_level The level of fuse optimization. + * -1 indicates that the level will be inferred from pass context. + * \return The Pass. + */ +TVM_DLL Pass FuseOps(int fuse_opt_level = -1); + +/*! + * \brief The pattern object used as the input of FuseOpsByPattern. For bindings to be + * fused, it needs to be matched with `pattern` and the `check` function needs to return + * true. + */ +class FusionPatternNode : public Object { + public: + /*! + * \brief The name of pattern. It becomes the value of the kComposite attribute + * of a fused function after successful matching + */ + String name; + + /*! + * \brief The dataflow pattern that will be used to match expression in the DataflowBlock. + * All the call nodes covered by the pattern will be extracted into the fused function. + */ + DFPattern pattern; + + /*! + * \brief The map which is used to extract important expressions from the pattern match + * result. All DFPattern in this map should be part of the `pattern`. + */ + Map annotation_patterns; + + /*! + * \brief The function to determine whether the match result is accepted. This can be + * NullOpt if check function is not necessary for this pattern. + * + * It should have signature + * bool(const PatternCheckContext& context) + */ + Optional check; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("pattern", &pattern); + v->Visit("annotation_patterns", &annotation_patterns); + v->Visit("check", &check); + } + + static constexpr const char* _type_key = "relax.transform.FusionPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object); +}; + +class FusionPattern : public ObjectRef { + public: + FusionPattern(String name, DFPattern pattern, Map annotation_patterns, + Optional check); + + FusionPattern(String name, DFPattern pattern) : FusionPattern(name, pattern, {}, NullOpt) {} + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode); +}; + +/*! 
+ * \brief The input of FusionPattern::check. + */ +class PatternCheckContextNode : public Object { + public: + /*! + * \brief The expression that's matched with the FusionPattern::pattern. + */ + Expr matched_expr; + + /*! + * \brief A map which contains all expressions matched by the sub patterns in + * FusionPattern::annotation_patterns. + */ + Map annotated_expr; + + /*! + * \brief Map from variable to its value. It contains variables from bindings that + * is being fused by FuseOpsByPattern. + */ + Map matched_bindings; + + /*! + * \brief A map mapping variable definitions to a set of uses. It has all variables + * used in the function. + */ + Map> var_usages; + + /*! + * \brief Map from value to its bound variable. It doesn't have variables after the + * matched expression. + */ + Map value_to_bound_var; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("matched_expr", &matched_expr); + v->Visit("annotated_expr", &annotated_expr); + v->Visit("matched_bindings", &matched_bindings); + v->Visit("var_usages", &var_usages); + v->Visit("value_to_bound_var", &value_to_bound_var); + } + + static constexpr const char* _type_key = "relax.transform.PatternCheckContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object); +}; + +class PatternCheckContext : public ObjectRef { + public: + PatternCheckContext(Expr matched_expr, Map annotated_expr, + Map matched_bindings, Map> var_usages, + Map value_to_bound_var); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef, + PatternCheckContextNode); +}; + +/*! + * \brief Apply pattern matching to each function in the given module, and group matched + * expressions into a new function. The end result is similar to FuseOps, but fusion is driven + * completely by the provided patterns. + * + * \param patterns The patterns to detect. The order of the patterns determines the order + * of priority in which they are matched. Higher-priority patterns should come earlier in the list. + * \param bind_constants Whether or not to keep bound constants of the grouped function. + * \param annotate_codegen If true, wrap each created composite function with another function, + * whose body consists only of a call to the composite function, and annotate the outer function + * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the + * corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu". + * This must be True if the created composite functions are intended to be offloaded to + * an external backend without using the MergeCompositeFunctions pass. + * \return The Pass. + */ +TVM_DLL Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants = true, + bool annotate_codegen = false); + +/*! + * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new + * function. The new function will be annotated with kCodegen and GlobalSymbol attributes, + * and it is intented to be offloaded to an external backend. + * + * \return The Pass. + */ +TVM_DLL Pass MergeCompositeFunctions(); + +/*! + * \brief Fuse relax sub-function into a larger TIR function if possible. + this pass works together with FuseOps to perform operator fusion. + + * \return The Pass. + */ +TVM_DLL Pass FuseTIR(); + +/*! + * \brief Run codegen. + * \param target_options pairs of target name and compilation options + * \param entry_functions list of entry functions + * \return The Pass. 
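+ *
+ * A minimal invocation sketch (illustrative only; `mod` is assumed to be an
+ * IRModule whose composite functions were created by earlier passes):
+ *
+ * \code{.cpp}
+ * Map<String, Map<String, ObjectRef>> target_options;  // per-backend options, may stay empty
+ * mod = relax::transform::RunCodegen(target_options, {"main"})(mod);
+ * \endcode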
+ */ +TVM_DLL Pass RunCodegen(Optional>> target_options, + Array entry_functions); + +/*! + * \brief Decompose composite operators during inference. For example, the result + * of a batch norm which is indexed at tuple index 0 will be unpacked into a + * number of simplified operators. Operators like Attention, Erf, etc. can be also + * simplified into several operators as well. + * \return The Pass. + */ +TVM_DLL Pass DecomposeCompositeOperator(); + +/*! + * \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in \p + * op_impl_map, with replacement PrimFunc that could possibly have different layouts on i/o + * buffers. The layout transformations on i/o buffers is present in the \p op_buffer_transforms. The + * pass inserts the layout transformations in the call sites of PrimFuncs being replaced to + * transform i/o buffers into expected layout. + * + * \param op_impl_map Map from from kOperatorName attr (e.g., relax.conv2d) to replacement PrimFunc + * \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the + * PrimFunc i/o buffers. + * \return The Pass. + */ +TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, + const Map>& op_buffer_transforms); + +/*! + * \brief Layout conversion pass. + * \param desired_layouts The desired layouts for some operators. + * \return The Pass. + */ +TVM_DLL Pass ConvertLayout(Map> desired_layouts); + +/*! + * \brief Dead code elimination. + * \sa RemoveAllUnused + * Currently it removes: + * 1. Unused local VarBindings in a DataflowBlock. + * 2. Unused DataflowBlocks in a function. + * 3. Unused Relax functions in the module. + * We detect the call chain from the entry function, and remove all unused functions. + * \return The Pass. + */ +TVM_DLL Pass DeadCodeElimination(Array entry_functions); + +/*! + * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 + * only, and will automatically cast fp32 to fp16 for certain ops. + * \param out_dtype The output data type of gemm/conv, which is the data type of the accumulator. + * \return The Pass. + */ +TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_H_ diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h new file mode 100644 index 000000000000..b6224a6d6d9e --- /dev/null +++ b/include/tvm/relax/tuning_api.h @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/tuning_api.h + * \brief Relax Tuning Pass APIs. + */ +#ifndef TVM_RELAX_TUNING_API_H_ +#define TVM_RELAX_TUNING_API_H_ +#include +#include +#include + +#include +namespace tvm { +namespace relax { + +/*! 
\brief Helper function to unpack arguments in the array as parameters for the given packed + * function. */ +TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, + const Array& args) { + size_t num_args = args.size(); + std::vector values(num_args); + std::vector codes(num_args); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + const ObjectRef* ptr = args.template as()->begin(); + for (size_t i = 0; i < num_args; ++i) { + setter(i, *(ptr + i)); + } + + TVMRetValue rv; + f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv); + return rv; +} + +/*! \brief Choice manages a set of keys for transformation and constraint functions. */ +class ChoiceNode : public runtime::Object { + public: + /*! \brief ffi key for transformation function. */ + String transform_func_key; + /*! \brief ffi key for constraint function. */ + String constr_func_key; + Array transform_func_args; + Array constr_func_args; + + /*! \brief The default destructor. */ + virtual ~ChoiceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("transform_func_key", &transform_func_key); + v->Visit("transform_func_args", &transform_func_args); + v->Visit("constr_func_key", &constr_func_key); + v->Visit("constr_func_args", &constr_func_args); + } + + /*! \brief Getter for constr_func. */ + const runtime::PackedFunc GetConstrFunc() { + const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key); + ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key; + return *constr_func; + } + + /*! \brief Getter for transform_func. */ + const runtime::PackedFunc GetTransformFunc() { + auto* transform_func = tvm::runtime::Registry::Get(transform_func_key); + ICHECK(transform_func != nullptr) + << "transform_func_key is not registered: " << transform_func_key; + return *transform_func; + } + + /*! \brief Perform constr_func. */ + bool CheckConstr(const IRModule& mod) { + Array args(constr_func_args); + args.insert(args.begin(), mod); + return CallPackedWithArgsInArray(GetConstrFunc(), args); + } + + /*! \brief Perform transform_func. */ + IRModule ApplyTransformFunc(IRModule mod) { + // Apply transformation when constraint is satisfied. + if (CheckConstr(mod)) { + Array args(transform_func_args); + args.insert(args.begin(), GetRef(mod.CopyOnWrite())); + return CallPackedWithArgsInArray(GetTransformFunc(), args); + } + return mod; + } + + /*! + * \brief Serialize Choice as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Choice"; + TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object); +}; + +/*! \brief Managed reference to ChoiceNode */ +class Choice : public runtime::ObjectRef { + public: + TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args); + /*! \brief Deserialize JSON-style object into Choice */ + TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); +}; + +/*! \brief Knob manages a set of valid choices for an optimization. */ +class KnobNode : public runtime::Object { + public: + /*! \brief Name of the knob. */ + String name; + /*! \brief Decision space. */ + Map choices; + + /*! \brief The default destructor. */ + virtual ~KnobNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("choices", &choices); + } + + /*! 
\brief Check if a decision is valid. */ + bool IsValidDecision(String decision) { return choices.count(decision) > 0; } + + /*! \brief Apply decision if the constraint is satisfied. + Otherwise, return the original IRModule. + */ + IRModule Apply(IRModule mod, String decision) { + ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision; + return choices[decision]->ApplyTransformFunc(mod); + } + + /*! + * \brief Serialize Knob as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Knob"; + TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object); +}; + +/*! \brief Managed reference to KnobNode */ +class Knob : public runtime::ObjectRef { + public: + TVM_DLL explicit Knob(String name, Map choices); + /*! \brief Deserialize JSON-style object into Knob */ + TVM_DLL static Knob FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode); +}; + +/*! \brief Trace manages history of optimization decisions. */ +class TraceNode : public runtime::Object { + public: + /*! \brief Input IRModule. */ + IRModule in_mod; + /*! \brief Output IRModule. */ + mutable IRModule out_mod; + // TODO(sunggg): can we move knobs and decisions into private? + /*! \brief Knobs that are applied so far. */ + Array knobs; + /*! \brief Decisions made for the knobs. */ + Array decisions; + /*! \brief Performance of out_mod. */ + mutable double perf = -1; + /*! \brief Length of the decision history. */ + mutable int size = 0; + /*! \brief The default destructor. */ + virtual ~TraceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("in_mod", &in_mod); + v->Visit("out_mod", &out_mod); + v->Visit("knobs", &knobs); + v->Visit("decisions", &decisions); + v->Visit("perf", &perf); + v->Visit("size", &size); + } + + /*! \brief Verify current decision history. */ + bool Verify() const { + if (knobs.size() != decisions.size()) return false; + int n = knobs.size(); + for (int i = 0; i < n; i++) { + if (!knobs[i]->IsValidDecision(decisions[i])) return false; + } + return true; + } + + /*! \brief Add a knob and its decision to the current trace. */ + IRModule Add(Knob knob, String decision) { + out_mod = knob->Apply(out_mod, decision); + knobs.push_back(knob); + decisions.push_back(decision); + // perf number should be initialized after new decision is applied. + perf = -1; + // increment history size. + size++; + return out_mod; + } + + /*! + * \brief Serialize Trace as a JSON-style object + * \param include_in_mod Boolean config to include input IRModule in the output. + * \return The JSON-style object + */ + ObjectRef AsJSON(bool include_in_mod = true) const; + + /*! \brief Set the performance. */ + void SetPerf(double _perf) { perf = _perf; } + /*! \brief Set output module. */ + void SetOutMod(IRModule mod_) { out_mod = mod_; } + + static constexpr const char* _type_key = "relax.tuning_api.Trace"; + TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object); +}; + +/*! \brief Managed reference to TraceNode */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing knobs and their decisions + * \param in_mod Input IRModule + * \param knobs The knobs used + * \param decisions The decisions made in sampling + */ + TVM_DLL explicit Trace(IRModule in_mod, Array knobs, Array decisions); + /*! 
\brief Deserialize JSON-style object into Trace */ + TVM_DLL static Trace FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode); +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + Trace trace; + /*! \brief The measurement record in seconds. */ + Optional> run_secs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + } + + static constexpr const char* _type_key = "relax.tuning_api.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. + * \param include_irmod Boolean config to include IRModules in the output. + * \return JSON object + */ + ObjectRef AsJSON(bool include_irmod = false) const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + */ + TVM_DLL explicit TuningRecord(Trace trace, Optional> run_secs); + /*! + * \brief Create a tuning record from a json object. + * \param json_obj The json object. + * \return The tuning record created. + */ + TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); +}; + +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const meta_schedule::Workload& a, const meta_schedule::Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + +/* \brief The abstract interface of database. */ +class DatabaseNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~DatabaseNode() = default; + /*! + * \brief Check if the database has the given workload. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + virtual bool HasWorkload(const IRModule& mod) = 0; + /*! + * \brief Check if the database has a measurement record for the given workload and target pair. + * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the measurement record for given workload and target pair. + */ + virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target) = 0; + /*! + * \brief Check if the database has a tuning record for the given workload and target pair. + * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the tuning record for the given workload and target pair. + */ + virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0; + /*! + * \brief Look up or add workload to the database if missing. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + virtual meta_schedule::Workload CommitWorkload(const IRModule& mod) = 0; + /*! + * \brief Add a measurement record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. 
+ * \param record Measurement record to be added. + */ + virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target, const Array& record) = 0; + /*! + * \brief Add a tuning record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. + * \param record Tuning record to be added. + */ + virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) = 0; + /*! + * \brief Get the top K tuning records of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. + */ + virtual Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) = 0; + /*! + * \brief Get the measurement record of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \return Measurement. + */ + virtual Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) = 0; + + static constexpr const char* _type_key = "relax.tuning_api.Database"; + TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); +}; + +/*! + * \brief Managed reference to DatabaseNode. + * \sa DatabaseNode + */ +class Database : public runtime::ObjectRef { + public: + /*! + * \brief Create a default database that uses JSON file for tuning records. + * \param path_workload The path to the workload table. + * \param path_tuning_record The path to the tuning record table. + * \param path_measurement_record The path to the measurement_record table. + * \param allow_missing Whether to create new file when the given path is not found. + */ + TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TUNING_API_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h new file mode 100644 index 000000000000..9c20a524353a --- /dev/null +++ b/include/tvm/relax/type.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/type.h + * \brief Relax Types. + */ +#ifndef TVM_RELAX_TYPE_H_ +#define TVM_RELAX_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! 
\brief Indicates the number of dimensions of a tensor is unknown at compile time. */ +static constexpr int kUnknownNDim = -1; + +class ShapeTypeNode : public TypeNode { + public: + /*! \brief size of the shape. */ + int ndim; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); } + + static constexpr const char* _type_key = "relax.ShapeType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); +}; + +class ShapeType : public Type { + public: + // TODO(relax-team): remove the default value later. + TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); +}; + +class ObjectTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode); +}; + +class ObjectType : public Type { + public: + TVM_DLL ObjectType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode); +}; + +class DynTensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknwon number of + * dimensions. + */ + int ndim; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(ndim); + hash_reduce(dtype); + } + + inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + inline bool IsUnknownDtype() const { return dtype.is_void(); } + + static constexpr const char* _type_key = "relax.DynTensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); +}; + +/*! + * \brief Managed reference to DynTensorTypeNode. + * \sa DynTensorTypeNode. + */ +class DynTensorType : public Type { + public: + /*! + * \brief Constructor. + * \param ndim The number of dimensions of the tensor. + * \param dtype The runtime dtype of the tensor's elements. + * \param span The span. + */ + TVM_DLL DynTensorType(int ndim, DataType dtype, Span span = Span()); + + /*! + * \brief Create a DynTensorType with unknown ndim. 
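+   *
+   * A minimal usage sketch (illustrative only; it just exercises the constructors declared in
+   * this header):
+   * \code
+   *   // A rank-2 float32 tensor type vs. a tensor type whose rank is unknown at compile time.
+   *   DynTensorType t2(2, DataType::Float(32));
+   *   DynTensorType t_any = DynTensorType::CreateUnknownNDim(DataType::Float(32));
+   *   ICHECK(t_any->IsUnknownNdim());
+   * \endcode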
+ */ + TVM_DLL static DynTensorType CreateUnknownNDim(DataType dtype, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); +}; + +class PackedFuncTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.PackedFuncType"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode); +}; + +class PackedFuncType : public Type { + public: + TVM_DLL PackedFuncType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TYPE_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h new file mode 100644 index 000000000000..e7d928c4aef4 --- /dev/null +++ b/include/tvm/relax/utils.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/utils.h + * \brief Utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_UTILS_H_ +#define TVM_RELAX_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Utility data structure for generating unique names for IR construction. + */ +class NameTable { + public: + /*! + * \brief Generate a unique name with a specified prefix. + * \param prefix The name prefix. + * \return The generated name. + */ + inline std::string GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = alloc_map_.find(prefix); + if (it != alloc_map_.end()) { + while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { + } + } + alloc_map_[unique_prefix] = 0; + return unique_prefix; + } + + NameTable() = default; + + template + explicit NameTable(Iter begin, Iter end, Lambda f) { + // static_assert is more reader-friendly than SFINAE when template specialization is not needed. 
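+    // Seed the per-prefix counters: each existing name is split into {prefix}{trailing digits},
+    // and the largest digit suffix seen for a prefix is recorded so that GetUniqueName() will
+    // never hand out a name that already exists.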
+ static_assert(std::is_convertible::value, + "Lambda f must has a signature of [?](*it) -> string {}"); + for (auto it = begin; it != end; ++it) { + const std::string& name = f(*it); + const size_t idx_last_first_num = std::distance( + std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), + name.rend()); + // name = {O = others}{D = consecutive digits} + // let O -> prefix; + std::string prefix = name.substr(0, idx_last_first_num); + ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name; + if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0; + if (idx_last_first_num < name.size()) { // has some digits. + // let D's nearest natural number -> idx; + // note: stoul("000123") = 123; + alloc_map_[prefix] = + std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num))); + } + } + } + + template + explicit NameTable(Iter begin, Iter end) + : NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {} + + private: + std::unordered_map alloc_map_; +}; + +/*! + * \brief Bind the variables to a Relax expression. This is a helper + * function usually called by other pass functions to help optimizations. + * If any free variables are introduced into a function, those are added + * to the function parameters. + * Additionally this may change the order of parameters if you map a variable + * to a variable. + * + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. + * \param symbolic_var_map The map from symbolic var to the expr it binds to. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds, + const tvm::Map& symbolic_var_map = {}); + +/*! + * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean + * dtype). + * + * \param sinfo The input StructInfo. + * \param permit_unknown_rank If true, it will permit the input type to have unknown rank + * (ndim of -1), which will require a dynamic check. + * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype + * (namely, void), which will require a dynamic check. + * + * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown + * rank or dtype) + */ +TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); + +/*! + * \brief Check if the given expression is a "leaf" node or tuple node for normalization purposes. + * + * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, + * GlobalVar, Op, ExternFunc. + * + * Tuples are included in this list mainly for convenience in grouping operator arguments. + * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that + * values nested inside them are also leaves. + * + * \param expr The input expression + * + * \return True iff the input expression is a "leaf" node (a value allowed to appear + * inline without being bound to a var during normalization). + */ +TVM_DLL bool IsLeafOrTuple(const Expr& expr); + +/*! + * \brief Copy the given function. All variables that are bound inside the original function + * would be copied to satisfy the restriction in the well-formed check: Variables in + * Relax must be bound exactly once. 
This also ensures that both the function and its copy + can be inserted into the same IRModule, and be checked for structural equality + against an IRModule created by TVMScript. + * + * \param func The relax function to copy. + * \return The copied function. + */ +TVM_DLL Function CopyWithNewVars(Function func); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_UTILS_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 5f591f1d89ad..4adaac02ab3e 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -60,7 +60,7 @@ using Sequential = tvm::transform::Sequential; */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! \brief Remove let-bound expressions which do not effect the program result. * diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 508b34b3517e..704eb1b576a8 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -218,6 +218,10 @@ class TVM_DLL ModuleNode : public Object { * \return The corresponding function. */ const PackedFunc* GetFuncFromEnv(const std::string& name); + + /*! \brief Clear all imports of the module. */ + void ClearImports() { imports_.clear(); } + /*! \return The module it imports from */ const std::vector& imports() const { return imports_; } diff --git a/include/tvm/runtime/relax_vm/builtin.h b/include/tvm/runtime/relax_vm/builtin.h new file mode 100644 index 000000000000..b994e44ae88d --- /dev/null +++ b/include/tvm/runtime/relax_vm/builtin.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/builtin.h + * \brief Builtin runtime APIs. + */ +#ifndef TVM_RUNTIME_RELAX_VM_BUILTIN_H_ +#define TVM_RUNTIME_RELAX_VM_BUILTIN_H_ + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief Op code used in built-in match-shape function. + * + * The function takes the following signature: + + * MatchShape(input_shape, shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n], err_ctx) + * + * This function provides runtime shape population and checking support for match-cast. + * When a shape variable appears for the first time, we should load the shape and + * populate the variable. When a shape variable already appears, we should + * assert that it already equals an existing shape value. + * + * NOTE: It is OK to pass nullptr shape_heap if all codes are AssertEqualToImm. + */ +enum class MatchShapeCode : int { + /*! + * \brief Perform an assertion that shape equals immediate. + * + * assert input_shape[i] == r[i] + */ + kAssertEqualToImm = 0, + /*!
+ * \brief This is the first time we see a symbolic shape variable, store to heap. + * + * shape_heap[r[i]] = input_shape[i] + */ + kStoreToHeap = 1, + /*! + * \brief Skip and do not do anything. + */ + kNoOp = 2, + /*! + * \brief Perform an assertion that the shape equals a loaded value. + * + * assert input_shape[i] == shape_heap[r[i]] + */ + kAssertEqualToLoad = 3, +}; + +/*! + * \brief Op code used in builtin function MakeShape. + * + * MakeShape(shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n]). + * + * \note It is OK to pass nullptr to shape_heap if all codes are UseImm. + */ +enum class MakeShapeCode : int { + /*! \brief Use the following r[i] as immediate shape value. */ + kUseImm = 0, + /*! + * \brief Load shape value from the shape_heap[r[i]]. + */ + kLoadShape = 1, +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RELAX_VM_BUILTIN_H_ diff --git a/include/tvm/runtime/relax_vm/bytecode.h b/include/tvm/runtime/relax_vm/bytecode.h new file mode 100644 index 000000000000..91d182325886 --- /dev/null +++ b/include/tvm/runtime/relax_vm/bytecode.h @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/bytecode.h + * \brief The bytecode for the virtual machine. + */ +#ifndef TVM_RUNTIME_RELAX_VM_BYTECODE_H_ +#define TVM_RUNTIME_RELAX_VM_BYTECODE_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief The storage type for the bytecode in the VM. + */ +using ExecWord = int64_t; + +/*! \brief A register name. */ +using RegName = ExecWord; + +/*! + * \brief An alias for the integer type used ubiquitously in the VM. + */ +using Index = ExecWord; + +/*! + * \brief An enumeration of Relax's opcodes. + * + * The opcode is used to implement instruction + * as a tagged union. + */ +enum class Opcode { + Call = 1U, + Ret = 2U, + Goto = 3U, + If = 4U, +}; + +/*! \brief A single virtual machine instruction. + * + * The representation of the instruction is as + * a tagged union. + * + * The first field represents which instruction, + * and by extension which field of the union + * is active. + */ +struct Instruction { + /*! \brief The number of bits used for storing the argument kind. */ + static constexpr ExecWord kKindBit = 8; + /*! \brief The number of bits used for storing the value. */ + static constexpr ExecWord kValueBit = sizeof(ExecWord) * 8 - kKindBit; + /*! \brief The bit mask of the value part. */ + static constexpr ExecWord kValueMask = (static_cast(1) << kValueBit) - 1; + /*! \brief Maximum possible value, use 1 bit for sign. */ + static constexpr ExecWord kValueMaxLimit = (static_cast(1) << (kValueBit - 1)) - 1; + /*! \brief Minimum possible value, remove 1 slot to keep things symmetric.
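+   *
+   * For example, with a 64-bit ExecWord and kKindBit = 8, kValueBit is 56, so an Arg packs an
+   * 8-bit kind tag together with a sign-extended value of roughly +/- (2^55 - 1), which is far
+   * larger than any realistic register, constant, or function index.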
*/ + static constexpr ExecWord kValueMinLimit = -kValueMaxLimit; + /*! \brief Begining of special register section. */ + static constexpr RegName kBeginSpecialReg = static_cast(1) << 54; + /*! \brief Random magic number that represents void argument, indicate null value */ + static constexpr RegName kVoidRegister = kBeginSpecialReg + 0; + /*! \brief Random magic number that represents the VM context */ + static constexpr RegName kVMRegister = kBeginSpecialReg + 1; + /*! + * \brief The kind of instruction's argument. + */ + enum class ArgKind : int { kRegister = 0, kImmediate = 1, kConstIdx = 2, kFuncIdx = 3 }; + /*! + * \brief The auxiliary data structure for instruction argument. + */ + class Arg { + public: + /*! \brief Construct a void argument. */ + Arg() : data_(Instruction::kVoidRegister) {} + /*! + * \brief construct Arg from data. + * \param data The data repr. + */ + static Arg FromData(ExecWord data) { return Arg(data); } + /*! + * \brief construct a register Arg. + * \param reg The register number. + * \return The constructed arg. + */ + static Arg Register(RegName reg) { return Arg(ArgKind::kRegister, reg); } + /*! + * \brief construct a ConstIdx arg. + * \param index The constant index. + * \return The constructed arg. + */ + static Arg ConstIdx(Index index) { return Arg(ArgKind::kConstIdx, index); } + /*! + * \brief construct a immediate arg. + * \param imm_value The immediate value. + * \return The constructed arg. + */ + static Arg Immediate(int64_t imm_value) { return Arg(ArgKind::kImmediate, imm_value); } + /*! + * \brief construct a FuncIdx arg. + * \param index The func index in the function table. + * \return The constructed arg. + */ + static Arg FuncIdx(Index index) { return Arg(ArgKind::kFuncIdx, index); } + /*! + * \brief Get the kind of argument.. + * \return The kind of argument. + */ + ArgKind kind() const { + uint8_t kind = (data_ >> kValueBit) & 0xFF; + return Instruction::ArgKind(kind); + } + /*! + * \brief Get the value of argument. + * \return The value of argument. + * \note We store both positive and negative values by sign extension. + */ + ExecWord value() const { return ((data_ & kValueMask) << kKindBit) >> kKindBit; } + /*! + * \brief Get the raw data repr of the arg. + * \return The raw data. + */ + ExecWord data() const { return data_; } + + private: + /*! \brief Construct from the data. */ + explicit Arg(ExecWord data) : data_(data) {} + /*! \brief Construct from the kind and value. */ + Arg(ArgKind kind, Index value) { + ICHECK_LE(value, kValueMaxLimit); + ICHECK_GE(value, kValueMinLimit); + data_ = (static_cast(kind) << kValueBit) | (value & kValueMask); + } + /*! \brief The underlying stored data. */ + ExecWord data_; + }; + /*! \brief The instruction opcode. */ + Opcode op; + union { + struct /* Call */ { + /*! \brief The destination register. */ + RegName dst; + /*! \brief The index into the packed function table. */ + Index func_idx; + /*! \brief The number of arguments to the packed function. */ + Index num_args; + /*! \brief The arguments of the packed function. */ + Arg* args; + }; + struct /* Ret */ { + /*! \brief The return result. */ + RegName result; + }; + struct /* Goto */ { + /*! \brief The jump offset. */ + Index pc_offset; + }; + struct /* If */ { + /*! \brief The register containing the cond value. */ + RegName cond; + /*! \brief The program counter offset for the false branch. */ + Index false_offset; + }; + }; + /*! + * \brief Construct a Call instruction. + * \param func_idx The index of the function to call. 
+ * \param num_args The number of arguments. + * \param args The input arguments. + * \param dst The destination register. + * \return The call instruction. + */ + static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst); + /*! + * \brief Construct a return instruction. + * \param result The register containing the return value. + * \return The return instruction. + */ + static Instruction Ret(RegName result); + /*! + * \brief Construct a goto instruction. + * \param pc_offset The register containing the jump offset. + * \return The goto instruction. + */ + static Instruction Goto(RegName pc_offset); + /*! + * \brief Construct an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. + * \return The If instruction. + */ + static Instruction If(RegName cond, Index false_offset); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_BYTECODE_H_ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h new file mode 100644 index 000000000000..5833cb4718f0 --- /dev/null +++ b/include/tvm/runtime/relax_vm/executable.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/executable.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ +#define TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ + +#include +#include +#include + +#include +#include +#include + +#include "./bytecode.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief Information entry in executable function table. + * + * Contains metadata about the compiled function, as + * well as the compiled VM instructions. + */ +struct VMFuncInfo { + /*! \brief kind of the function. */ + enum class FuncKind : int { + /*! \brief system level packed function */ + kPackedFunc = 0, + /*! \brief VM function. */ + kVMFunc = 1, + /*! \brief VMTIR function. */ + kVMTIRFunc = 2, + }; + /*! \brief The kind of function. */ + FuncKind kind; + /*! \brief The function's name, global symbol */ + std::string name; + /*! \brief The start instruction index of the function. */ + Index start_instr = 0; + /*! \brief The end instruction index of the function. */ + Index end_instr = 0; + /*! \brief The number of arguments of the function. */ + Index num_args = 0; + /*! \brief The register file size of the function. */ + Index register_file_size = 0; + /*! \brief The function parameter names.*/ + std::vector param_names; + + // defined customized loader save + void Save(dmlc::Stream* writer) const; + bool Load(dmlc::Stream* reader); +}; + +/*! + * \brief The executable emitted by the VM compiler. 
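+ *
+ * A sketch of the save/load round trip declared below (the file name and the way `exec` is
+ * obtained are assumptions for illustration, not part of this header):
+ * \code
+ *   exec->SaveToFile("mod.vmexec", "");                       // serialize the executable
+ *   Module loaded = Executable::LoadFromFile("mod.vmexec");   // reload it as a runtime::Module
+ * \endcode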
+ * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + */ +class Executable : public runtime::ModuleNode { + public: + /*! + * \brief Get a PackedFunc from the executable module. + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + + /*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globals and constants, etc. + * \return The statistics represented by a string. + */ + std::string Stats() const; + /*! + * \brief Get the i-th instruction from the executable. + * \param i The index of the instruction to be fetched. + * \return The instruction. + */ + Instruction GetInstruction(Index i) const; + /*! + * \brief Set j-th byte data of i-th instruction to val. + * \param i The index of the instruction to be updated. + * \param j The index of the byte data of the instruction to be updated. + * \param val The value to be set + */ + void SetInstructionData(Index i, Index j, ExecWord val); + /*! + * \brief Print the instructions as text format. + * \return The text format of the instructions. + */ + String AsText() const; + /*! + * \brief Print the instructions as python program. + * \return The python program of the instructions, represented by a string. + */ + String AsPython() const; + /*! + * \brief Write the Executable to the binary stream in serialized form. + * \param stream The binary stream to save the executable to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + /*! + * \brief Load Executable from the binary stream in serialized form. + * \param stream The binary stream that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromBinary(void* stream); + /*! + * \brief Write the Executable to the provided path as a file containing its serialized content. + * \param file_name The name of the file to write the serialized data to. + * \param format The target format of the saved file. + */ + void SaveToFile(const std::string& file_name, const std::string& format) final; + /*! + * \brief Load Executable from the file. + * \param file_name The path of the file that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromFile(const std::string& file_name); + + /*! \brief The virtual machine's function table. */ + std::vector func_table; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map func_map; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief The offset of instruction. */ + std::vector instr_offset; + /*! \brief The byte data of instruction. */ + std::vector instr_data; + + virtual ~Executable() {} + + const char* type_key() const final { return "relax.Executable"; } + + private: + /*! + * \brief Save the globals. + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + /*! + * \brief Save the constant pool. + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + /*! + * \brief Save the instructions. + * \param strm The input stream. 
+ */ + void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void SavePackedFuncNames(dmlc::Stream* strm); + /*! + * \brief Load the globals. + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + /*! + * \brief Load the constant pool. + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + /*! + * \brief Load the instructions. + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void LoadPackedFuncNames(dmlc::Stream* strm); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +namespace dmlc { +DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::relax_vm::VMFuncInfo, true); +} // namespace dmlc +#endif // TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ diff --git a/include/tvm/runtime/relax_vm/memory_manager.h b/include/tvm/runtime/relax_vm/memory_manager.h new file mode 100644 index 000000000000..9234e9151ce0 --- /dev/null +++ b/include/tvm/runtime/relax_vm/memory_manager.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/memory_manager.h + * \brief Abstract device memory management API + */ +#ifndef TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ +#define TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +struct Buffer { + /*! \brief The pointer to the allocated block of memory. */ + void* data{nullptr}; + /*! \brief The size of the block. */ + size_t size{0}; + /*! \brief The device of the allocated buffers. */ + Device device; +}; + +enum AllocatorType { + kNaive = 1, + kPooled, +}; + +class Allocator { + public: + explicit Allocator(AllocatorType type) : type_(type) {} + virtual ~Allocator() = default; + /*! \brief Allocate an empty NDArray using from the allocator. + * \param shape The shape of the NDArray. + * \param dtype The datatype of the NDArray. + * \param dev The device where the array is allocated. + * \return The empty NDArray. + */ + runtime::NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev); + /*! \brief Return the allocator type. */ + inline AllocatorType type() const { return type_; } + /*! \brief Allocate a buffer given a size, alignment and type. + * \param nbytes The size of the buffer. + * \param alignment The alignment of the buffer. + * \param type_hint A type hint to the allocator. + * \return A sized allocation in the form of a buffer. + */ + virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; + /*! 
\brief Free a buffer allocated by the allocator. + * \param buffer The buffer to free. + */ + virtual void Free(const Buffer& buffer) = 0; + + private: + AllocatorType type_; +}; + +class MemoryManager { + public: + static MemoryManager* Global(); + /*! + * \brief Get or create an allocator given the device and allocator type. + * \param dev The TVM device + * \param type The allocator type + * \return The memory allocator. + */ + static Allocator* GetOrCreateAllocator(Device dev, AllocatorType type); + /*! + * \brief Get an allocator given the device. + * \param dev The TVM device + * \return The memory allocator. + */ + static Allocator* GetAllocator(Device dev); + + private: + MemoryManager() {} + + private: + std::mutex mutex_; + std::unordered_map> allocators_; +}; + +/*! \brief An object representing a storage allocation. */ +class StorageObj : public Object { + public: + /*! \brief The index into the VM function table. */ + Buffer buffer; + + /*! \brief Allocate an NDArray from a given piece of storage. */ + runtime::NDArray AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype); + + /*! \brief The deleter for an NDArray when allocated from underlying storage. */ + static void Deleter(Object* ptr); + + ~StorageObj() { + auto alloc = MemoryManager::Global()->GetAllocator(buffer.device); + alloc->Free(buffer); + } + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.Storage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object); +}; + +/*! \brief reference to storage. */ +class Storage : public ObjectRef { + public: + explicit Storage(Buffer buffer); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h new file mode 100644 index 000000000000..95a208015945 --- /dev/null +++ b/include/tvm/runtime/relax_vm/vm.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/vm.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_VM_H_ +#define TVM_RUNTIME_RELAX_VM_VM_H_ + +#ifndef TVM_RELAX_VM_ENABLE_PROFILER +#define TVM_RELAX_VM_ENABLE_PROFILER 1 +#endif + +#include +#include +#include + +#include "./bytecode.h" +#include "./executable.h" +#include "./memory_manager.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief Possible instrument actions. + */ +enum class VMInstrumentReturnKind : int { + /*! \brief Running as normal. */ + kNoOp = 0, + /*! \brief Skip the following run, only valid in before. */ + kSkipRun = 1, +}; + +/*! 
+ * \brief An object representing a vm closure. + */ +class VMClosureObj : public ClosureObj { + public: + /*! + * \brief The function name. The function could be any + * function object that is compatible to the VM runtime. + */ + String func_name; + + /*! + * \brief The implementation of the Closure. + * \note This function takes context pointer(VirtualMachine*) + * as the first argument. The rest of arguments follows + * the same arguments as the normal function call. + */ + PackedFunc impl; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.vm.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj); +}; + +/*! \brief reference to closure. */ +class VMClosure : public Closure { + public: + VMClosure(String func_name, PackedFunc impl); + TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj); + + /*! + * \brief Create another PackedFunc with last arguments already bound to last_args. + * + * This is a helper function to create captured closures. + * \param func The input func, can be a VMClosure or PackedFunc. + * \param last_args The arguments to bound to in the end of the function. + * \note The new function takes in arguments and append the last_args in the end. + */ + static PackedFunc BindLastArgs(PackedFunc func, std::vector last_args); +}; + +/*! + * \brief The virtual machine. + * + * The virtual machine contains all the current execution state, + * as well as the executable. + * + * The goal is to have a single self-contained object, + * enabling one to easily pass around VMs, execute them on + * multiple threads, or serialize them to disk or over the + * wire. + */ +class VirtualMachine : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the virtual machine for a set of devices. + * \param devices The set of TVM devices. + * \param alloc_types The allocator types for each device. + */ + virtual void Init(const std::vector& devices, + const std::vector& alloc_types) = 0; + /*! + * \brief Load the executable for the virtual machine. + * \param exec The executable. + */ + virtual void LoadExecutable(ObjectPtr exec) = 0; + /*! + * \brief Get global function in the VM. + * \param func_name The name of the function. + * \return The closure + */ + virtual VMClosure GetClosure(const String& func_name) = 0; + /*! + * \brief Invoke closure or packed function using PackedFunc convention. + * \param closure_or_packedfunc A VM closure or a packed_func. + * \param args The input arguments. + * \param rv The return value. + */ + virtual void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) = 0; + /*! + * \brief Set an instrumentation function. + * + * If instrument is present, the function will be called + * before/after each Call instruction. + * + * bool instrument(func, func_symbol, before_run, args...) + * + * - func: Union[VMClosure, PackedFunc], the function object. + * - func_symbol: string, the symbol of the function. + * - before_run: bool, whether it is before or after call. + * - ret_value: Only valid in after run, otherwise it is null. + * - args: the arguments being passed to call. + * + * instrument can return an int which corresponds to the action value. + * \sa VMInstrumentAction + * + * \param instrument The instrument function. + */ + virtual void SetInstrument(PackedFunc instrument) = 0; + /*! + * \brief Create a specific instance of VM. + * \return Created VM + */ + static ObjectPtr Create(); + /*! 
+ * \brief Create an instance of VM with the profiling feature enabled. + * \return Created VM + */ + static ObjectPtr CreateProfiler(); + /*! + * \brief Helper function for vm closure functions to get the context ptr + * \param arg The argument value. + */ + static VirtualMachine* GetContextPtr(TVMArgValue arg) { + return static_cast(arg.operator void*()); + } + + ~VirtualMachine() {} + + const char* type_key() const final { return "relax.VirtualMachine"; } + + //-------------------------------------------------------------------------- + // The following section contains states that other builtin can depend on + //-------------------------------------------------------------------------- + /*! \brief The memory allocators. */ + std::vector allocators; + /*! \brief Runtime physical device list. */ + std::vector devices; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_VM_H_ diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..6e758372b94b 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -38,13 +39,24 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; + /*! \brief IRModule's global_infos */ + Map> global_infos; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); + v->Visit("global_infos", &global_infos); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. 
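+ *
+ * Typical use together with `DeclFunction` above, sketched with placeholder names
+ * (`gv`, `sig`, and `impl` are illustrative, not part of this header):
+ * \code
+ *   GlobalVar gv = DeclFunction("my_func", sig);  // sig: a BaseFunc carrying only the signature
+ *   // ... build call sites that reference gv ...
+ *   DefFunction("my_func", impl);                 // impl: the complete function definition
+ * \endcode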
+ * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h new file mode 100644 index 000000000000..0f544d3abcc2 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/*! \brief The base ir_builder frame for the relax dialect. */ +class RelaxFrameNode : public IRBuilderFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); +}; + +class RelaxFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + + protected: + RelaxFrame() = default; +}; + +/*! \brief The base ir_builder frame for frames with SeqExpr + i.e. Functions, If branches + */ +class SeqExprFrameNode : public RelaxFrameNode { + public: + /*! \brief The binding blocks inside the frame. */ + Array binding_blocks; + /*! \brief The frame output expr. `NullOpt` when undefined. */ + Optional output; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() override; + void ExitWithScope() override; +}; + +class SeqExprFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); +}; + +/*! \brief The ir_builder frame for the relax function. */ +class FunctionFrameNode : public SeqExprFrameNode { + public: + /*! + * \brief The function name. + * \note The name will not be specified in constructor, so it is "Optional", + * However, we must specify the name by `R.func_name` before exit this frame. + */ + Optional name; + /*! \brief The function params. */ + Array params; + /*! + * \brief The function return struct info. + * \note Usually the function return type can be deduced by the function body. + * But we can use this field to specify a more "accurate" return type. 
+ * i.e. If the `ret_struct_info` is None, try to use the deduced type from body + * If the `ret_struct_info` is not None, we can still take body.struct_info + * if we ret_struct_info is base of body.struct_info. If not, we will + * take the specified `ret_struct_info`. + */ + Optional ret_struct_info; + + /*! \brief The function attributes. */ + Map attrs; + /*! \brief The block builder to create Relax function. */ + tvm::relax::BlockBuilder block_builder; + + void VisitAttrs(tvm::AttrVisitor* v) { + SeqExprFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("params", ¶ms); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + // `block_builder` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + + public: + void ExitWithScope() final; +}; + +class FunctionFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); +}; + +/*! \brief The ir_builder frame for relax binding blocks. */ +class BlockFrameNode : public RelaxFrameNode { + public: + /*! \brief The flag that indicates whether the block is a dataflow block. */ + bool is_dataflow; + /*! \brief The variables emitted in this block. */ + Array emitted_vars; + /*! + * \brief A boolean indicating if the dataflow block is ended of construction. + * If it is true, any new binding trying to be emitted into this block will cause an error. + * \note Only used for a dataflow block. + */ + bool block_ended; + /*! + * \brief The output vars of the dataflow block. + * \note Only used for a dataflow block. + */ + Array output_vars; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("is_dataflow", &is_dataflow); + v->Visit("emitted_vars", &emitted_vars); + v->Visit("output_vars", &output_vars); + // `block_ended` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); +}; + +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public RelaxFrameNode { + public: + /*! \brief The condition of the if statement. */ + tvm::relax::Expr condition; + /*! \brief The Bindings in the true branch. */ + Optional then_expr; + /*! \brief The Bindings in the false branch. */ + Optional else_expr; + /*! \brief The Binding var. */ + tvm::relax::Var var; + /*! \brief The binding var name. */ + String var_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_expr", &then_expr); + v->Visit("else_expr", &else_expr); + v->Visit("var", &var); + v->Visit("var_name", &var_name); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. 
+ * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ +class ThenFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ +class ElseFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); +}; + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h new file mode 100644 index 000000000000..42aa591a95b7 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/////////////////////////////// Function //////////////////////////////// + +/*! + * \brief Start a function frame. + * \return The created ir_builder Function frame. + */ +TVM_DLL FunctionFrame Function(); + +/*! + * \brief Add a parameter to the last function frame. + * \param name The name of the parameter. 
+ * \param struct_info The struct_info of the parameter. + * \return The created function parameter var. + */ +TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); + +/*! + * \brief Specify the name of the last function frame. + * \param name The function name. + */ +TVM_DLL void FuncName(const String& name); + +/*! + * \brief Specify the attrs of the last function frame. + * \param attrs The function attrs. + */ +TVM_DLL void FuncAttrs(Map attrs); + +/*! + * \brief Specify the return struct info of the last function frame. + * \param ret_sinfo The return struct info. + */ +TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); + +/*! + * \brief Specify the return value of the last function frame. + * \param value The return value. + */ +TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); + +///////////////////////////// BindingBlock ////////////////////////////// + +/*! + * \brief Start a binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame BindingBlock(); + +/*! + * \brief Start a dataflow binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame Dataflow(); + +/*! + * \brief Expose the dataflow block output variables as global ones + * \param vars The output variables of a dataflow block + */ +TVM_DLL void DataflowBlockOutput(const Array& vars); + +////////////////////////////// Bindings //////////////////////////////// + +/*! + * \brief Emit a binding to the last binding block frame. + * \param value The right side value of the bindings to be emitted. + * \param annotate_struct_info The optional struct info annotation for the emitted value. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var Emit( + const tvm::relax::Expr& value, + const Optional& annotate_struct_info = NullOpt); + +/*! + * \brief Emit a match_cast binding to the last binding block frame. + * \param value The value of the MatchCast to be emitted. + * \param struct_info The struct info of the MatchCast to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info); + +/*! + * \brief Emit a binding to the last binding block frame. + * \param binding The binding to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding); + +///////////////////////////// If Then Else ///////////////////////////// + +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ +IfFrame If(tvm::relax::Expr condition); +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ +ThenFrame Then(); +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 2c50f3c3157b..f5753afa560f 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode { } static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! 
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d7a2aec0b972..e3a853e4c7ea 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -34,6 +34,18 @@ namespace tvm { namespace tir { +#ifndef TVM_INDEX_DEFAULT_I64 +#define TVM_INDEX_DEFAULT_I64 1 +#endif +/*! \brief If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32. */ +inline DataType DefaultIndexType() { +#if TVM_INDEX_DEFAULT_I64 + return DataType::Int(64); +#else + return DataType::Int(32); +#endif +} + // forward declare Stmt class Stmt; @@ -135,7 +147,7 @@ class BufferNode : public Object { /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); + return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType(); } /*! \brief Determine the offset in the buffer of the given index. diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e8bcc028fc58..fa1737f20ed3 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -797,6 +797,50 @@ TVM_DLL const Op& start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); +/*! + * \brief Get an item from an any-list and return it. + * + * Any anylist_getitem(Handle anylist, + * int index) { + * return anylist[index]; + * } + * + * \note This intrinsic is only applicable when appearing + * in call_packed and anylist_setitem_call_packed. + */ +TVM_DLL const Op& anylist_getitem(); + +/*! + * \brief Reset and clear an item in an any-list. + * + * void anylist_resetitem(Handle anylist, + * int index) { + * anylist[index] = nullptr; + * } + * + * \note This intrinsic is only applicable when appearing + * in call_packed and anylist_setitem_call_packed. + */ +TVM_DLL const Op& anylist_resetitem(); + +/*! + * \brief Set an item into an any-list by running a packed function call. + * + * void anylist_setitem_call_packed(Handle anylist, + * int index, + * name, *args) { + * + * anylist[index] = call_packed(name, *args) + * } + * \note This intrinsic can be used in combination with anylist_getitem. + */ +TVM_DLL const Op& anylist_setitem_call_packed(); + +/*! + * \brief Same as anylist_setitem_call_packed but uses the C calling convention. + */ +TVM_DLL const Op& anylist_setitem_call_cpacked(); + +/*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 48328263fb55..406411391209 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -332,6 +332,13 @@ constexpr const char* kIsGlobalFunc = "tir.is_global_func"; */ constexpr const char* kIsHostFunc = "tir.is_host_func"; +/*! + * \brief Mark the function as scheduled, so the default schedule pass will skip it. + * + * Type: Integer + */ +constexpr const char* kIsScheduled = "tir.is_scheduled"; + } // namespace attr } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d4f537ff3169..72207a096e9c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,7 +56,7 @@ using tvm::transform::Sequential; */ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Inject prefetch instructions into stmt.
@@ -336,6 +336,14 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); +/*! + * \brief Force to narrow down indexing expressions and integer buffers to int32 dtype. + * + * \return The pass. + * \note This pass should not be used in default cases. + */ +TVM_DLL Pass ForceNarrowIndexToInt32(); + /*! * \brief Legalize bf16 compute Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. @@ -721,6 +729,18 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage(); */ TVM_DLL Pass InstrumentProfileIntrinsics(); +/*! + * \brief The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, + * allowing their build and execution on GPU devices. It examines all the blocks within the + * PrimFunc and conducts loop fusion, splitting, and reordering operations based on the loop extent + * and target information, such as the maximum thread block number and maximum thread per block. + * \note The primary objective of this pass is not to optimize performance, but rather to + * generate a valid GPU kernel for unscheduled or symbolic shape PrimFuncs. The pass is + * currently only working for CUDA targets. + * \return The Pass. + */ +TVM_DLL Pass DefaultGPUSchedule(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 43760bab1fd0..5636de11acec 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -25,7 +25,6 @@ #define TVM_TOPI_NN_GROUP_NORM_H_ #include -#include #include #include @@ -41,9 +40,17 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& int num_groups, int channel_axis, const Array& axes, double epsilon, std::string name = "T_group_norm", std::string tag = kInjective) { + const auto& data_type = data->dtype; + const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; + const auto& beta_type = beta.defined() ? 
beta->dtype : data_type; + ICHECK(data_type == gamma_type && data_type == beta_type) + << "group_norm: data, gamma and beta must have the same type"; + ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + << "group_norm: only support float32 and float16 for now"; + bool is_float16 = data_type == DataType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); - channel_axis = GetRealAxis(ndim, {channel_axis})[0]; + channel_axis = GetRealAxis(static_cast(ndim), {channel_axis})[0]; auto shape = data->shape; auto group_size = floordiv(shape[channel_axis], num_groups); @@ -56,8 +63,13 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& new_shape.push_back(shape[i]); } } - auto data_reshaped = reshape(data, new_shape); - // reshape gamma and beta, C -> G, C/G + Tensor data_reshaped; + if (is_float16) { + data_reshaped = cast(reshape(data, new_shape), DataType::Float(32)); + } else { + data_reshaped = reshape(data, new_shape); + } + // reshape gamma and beta, C -> G, C/G, cast to float32 if float16 Tensor gamma_reshaped; if (gamma.defined()) { gamma_reshaped = reshape(gamma, {num_groups, group_size}); @@ -70,7 +82,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& // get the new axes to normalize after reshape std::vector new_axes{channel_axis + 1}; for (auto axis : axes) { - int new_axis = GetRealAxis(ndim, {axis})[0]; + int new_axis = GetRealAxis(static_cast(ndim), {axis})[0]; if (new_axis < channel_axis) { new_axes.push_back(new_axis); } else if (new_axis > channel_axis) { @@ -81,7 +93,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& } std::sort(new_axes.begin(), new_axes.end()); - // sum x and x^2 + // sum x and x^2, cast to float32 if float16 ndim = data_reshaped->shape.size(); auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped); auto target_shape = @@ -113,7 +125,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - auto reduce_extent = make_const(data->dtype, 1); + auto reduce_extent = make_const(DataType::Float(32), 1); for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } @@ -129,8 +141,11 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& gamma_indices = {indices[channel_axis], indices[channel_axis + 1]}; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - auto group_norm = + PrimExpr group_norm = (data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon)); + if (is_float16) { + group_norm = Cast(DataType::Float(16), group_norm); + } if (gamma.defined()) { group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices)); } diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 93e5582ef184..ee0cba74dd3b 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -51,6 +51,14 @@ using namespace tvm::te; inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, const Array& axis, double epsilon, std::string name = "T_layer_norm", std::string tag = kInjective) { + const auto& data_type = data->dtype; + const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; + const auto& beta_type = beta.defined() ? 
beta->dtype : data_type; + ICHECK(data_type == gamma_type && data_type == beta_type) + << "layer_norm: data, gamma and beta must have the same type"; + ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + << "layer_norm: only support float32 and float16 for now"; + bool is_float16 = data_type == DataType::Float(16); // sum x and x^2 auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -60,7 +68,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); auto func = MakeTupleSumReducer(); - auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array& indices) { + auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, + &data](const Array& indices) { Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -75,8 +84,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& arg_counter++; } } - auto square = [](const PrimExpr& x) { return x * x; }; - return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); + auto square = [is_float16](const PrimExpr& x) { + if (is_float16) { + return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); + } + return x * x; + }; + if (is_float16) { + return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, + reduce_axes, nullptr); + } else { + return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); + } }; auto temp_x_x2 = @@ -101,6 +120,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); + if (is_float16) { + layer_norm = Cast(DataType::Float(16), layer_norm); + } layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); if (beta.defined()) { layer_norm = topi::add(layer_norm, beta(reduce_indices)); diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 9a3fa2720017..11d4bfb23415 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -74,9 +74,15 @@ def get_dll_directories(): dll_path.append(install_lib_dir) - if os.path.isdir(source_dir): - dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) - dll_path.append(os.path.join(source_dir, "web", "dist")) + # use extra TVM_HOME environment for finding libraries. + if os.environ.get("TVM_HOME", None): + tvm_source_home_dir = os.environ["TVM_HOME"] + else: + tvm_source_home_dir = source_dir + + if os.path.isdir(tvm_source_home_dir): + dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist", "wasm")) + dll_path.append(os.path.join(tvm_source_home_dir, "web", "dist")) dll_path = [os.path.realpath(x) for x in dll_path] return [x for x in dll_path if os.path.isdir(x)] diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py new file mode 100644 index 000000000000..f7dee4e3b80a --- /dev/null +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS attention kernels.""" +from .library import * + + +def instantiate_attention_template(attrs, func_args): + """Return CUTLASS host code for fused multi head attention + based on a template and the provided attribute map.""" + + bias_template = { + "B11S'": """ + CHECK(${arg3}->ndim == 2); // B, 1, 1, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideM = 0; // 0 + p.bias_strideH = 0; // 0 + p.bias_strideB = p.num_keys; // S' +""", + "B1SS'": """ + CHECK(${arg3}->ndim == 3); // B, 1, S, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideM = p.num_keys; // S' + p.bias_strideH = 0; // 0 + p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S +""", + "BNSS'": """ + CHECK(${arg3}->ndim == 4); // B, N, S, S' + + p.attn_bias_ptr = reinterpret_cast(${arg3}->data); + p.bias_strideM = p.num_keys; // S' + p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S + p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N +""", + } + + template = """ + using T = ${data_type}; + + CHECK(${arg0}->ndim == 4); // B, S, N, H + CHECK(${arg1}->ndim == 4); // B, S', N, H + CHECK(${arg2}->ndim == 4); // B, S', N, H' + CHECK(out0->ndim == 4); // B, S, N, H' + + using Attention = + AttentionKernel; + + typename Attention::Params p; + + p.query_ptr = reinterpret_cast(${arg0}->data); + p.key_ptr = reinterpret_cast(${arg1}->data); + p.value_ptr = reinterpret_cast(${arg2}->data); + p.logsumexp_ptr = nullptr; + p.output_ptr = reinterpret_cast(out0->data); + p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + cudaMalloc( + &p.output_accum_ptr, + ${output_size} * sizeof(Attention::output_accum_t) + ); + } + + p.num_heads = ${num_heads}; // N + p.num_batches = ${num_batches}; // B + p.head_dim = ${head_dim}; // H + p.head_dim_value = ${head_dim_value}; // H' + p.num_queries = ${num_queries}; // S + p.num_keys = ${num_keys}; // S' + p.scale = ${scale}; + + // stride for N + p.q_strideH = p.head_dim; // H + p.k_strideH = p.head_dim; // H + p.v_strideH = p.head_dim_value; // H' + + // stride for S + p.q_strideM = p.q_strideH * p.num_heads; // H * N + p.k_strideM = p.k_strideH * p.num_heads; // H * N + p.v_strideM = p.v_strideH * p.num_heads; // H' * N + p.o_strideM = p.head_dim_value * p.num_heads; // H' * N + + // stride for B + p.q_strideB = p.q_strideM * p.num_queries; // H * N * S + p.k_strideB = p.k_strideM * p.num_keys; // H * N * S' + p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S' + + ${bias_template} + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + static bool once = [&]() { + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + CHECK(Attention::check_supported(p)); + kernel_fn<<>>(p); +""" 
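+ # The ${bias_template} placeholder above is optional: when attrs carries a "bias_layout"
+ # key, the matching stride snippet from bias_template is spliced in by the substitution
+ # below; otherwise the placeholder is replaced with an empty string.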
+ + template = substitute_template( + template, + {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""}, + ) + + for i, arg in enumerate(func_args): + attrs["arg{}".format(i)] = arg + return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 363548fb2ba0..93d1331ac443 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -16,17 +16,23 @@ # under the License. # pylint: disable=invalid-name, dangerous-default-value, arguments-differ """Driver for partitioning and building a Relay module for CUTLASS offload.""" +import itertools import logging -import os import multiprocessing +import operator +import os +from functools import reduce +from typing import Optional, Sequence + import tvm -from tvm import relay, runtime +from tvm import relax, relay, runtime from tvm._ffi.registry import register_func from tvm.contrib.nvcc import get_cuda_version +from tvm.topi.utils import get_const_tuple from .gen_conv2d import CutlassConv2DProfiler from .gen_gemm import CutlassGemmProfiler -from .library import ConvKind +from .library import ConvKind, LayoutType logger = logging.getLogger("cutlass") @@ -52,6 +58,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): cutlass_root = _get_cutlass_path() cutlass_include = os.path.join(cutlass_root, "include") cutlass_util_include = os.path.join(cutlass_root, "tools/util/include") + cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention") kwargs = {} kwargs["cc"] = "nvcc" @@ -66,6 +73,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): "-std=c++17", "-I" + cutlass_include, "-I" + cutlass_util_include, + "-I" + cutlass_attention_include, ] if use_fast_math: kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID") @@ -521,6 +529,349 @@ def tune_cutlass_function( ) +def _get_call_node(expr: relax.Expr, op_name: str) -> Optional[relax.Call]: + node = None + + def fvisit(e): + nonlocal node + if isinstance(e, relax.Call) and e.op.name == op_name: + node = e + + relax.analysis.post_order_visit(expr, fvisit) + return node + + +def _extract_relax_function_signature(f): + signature = {} + + for i, arg in enumerate(f.params): + sinfo = arg.struct_info + signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape) + signature["arg%d_dtype" % i] = sinfo.dtype + + ret_sinfo = f.ret_struct_info + if ret_sinfo.shape is not None: + signature["ret_shape"] = list(ret_sinfo.shape) + else: + signature["ret_shape"] = None + signature["ret_dtype"] = ret_sinfo.dtype + + return signature + + +def _extract_arg_idx(pattern_name, f): + pattern_entry = relax.backend.get_pattern(pattern_name) + if pattern_entry is None: + raise ValueError(f"Unsupported op_type {pattern_name}") + var2val = relax.analysis.get_var2val(f) + matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val) + + func_args = list(f.params) + + arg_idx = {} + for name, annotation_pattern in pattern_entry.annotation_patterns.items(): + arg_expr = matched_expr[annotation_pattern] + if arg_expr not in func_args: + continue + arg_idx[name] = func_args.index(arg_expr) + + return arg_idx + + +def is_shape_valid_for_cutlass_matmul( + lhs_shape: Sequence[tvm.ir.PrimExpr], + rhs_shape: Sequence[tvm.ir.PrimExpr], +) -> bool: + """ + Check whether the shape of inputs can be handled by CUTLASS GEMM. 
+ + The stride-based batch matmul in CUTLASS cannot handle cases where some of + the batch dimensions need to be stretched while others don't. This means + it can only handle ND x ND whose batch dimensions match exactly on both sides, + as well as ND x 2D and 2D x ND. For example, it cannot handle matmul with shape + (2, 1, 4, 8) x (2, 3, 8, 16), because the batch stride of lhs is not constant. + """ + if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)): + # Reduction axis must be constant + return False + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + if lhs_batches == 1 or rhs_batches == 1: + # This could be regular matmul or batch matmul with shape ND x 2D or 2D x ND + return True + + analyzer = tvm.arith.Analyzer() + # If one side has fewer dimensions, use 1 to fill the gap + batch_dim_pairs = list( + itertools.zip_longest( + list(lhs_shape)[-3::-1], # Remove the last two dimensions and reverse + list(rhs_shape)[-3::-1], + fillvalue=1, + ) + ) + return all(analyzer.can_prove_equal(p[0], p[1]) for p in batch_dim_pairs) + + +@relax.expr_functor.mutator +class CutlassRelaxFunctionAnnotator(relax.PyExprMutator): + """A Relax function mutator that tunes and annotates CUTLASS composite functions + with shape, dtype and generated templates. + """ + + def __init__( + self, + mod, + conv2d_profiler: CutlassConv2DProfiler, + gemm_profiler: CutlassGemmProfiler, + options, + ): + super().__init__(mod) + self.options = options + self.conv2d_profiler = conv2d_profiler + self.gemm_profiler = gemm_profiler + + def handle_conv2d(self, f, op_type): + """Tune and annotate a conv2d op.""" + signature = _extract_relax_function_signature(f) + arg_idx = _extract_arg_idx(op_type, f) + op_attrs = _get_call_node(f.body, "relax.nn.conv2d").attrs + + data_arg = f"arg{arg_idx['lhs']}" + weight_arg = f"arg{arg_idx['rhs']}" + + d_shape = signature[f"{data_arg}_shape"] + w_shape = signature[f"{weight_arg}_shape"] + out_shape = signature["ret_shape"] + data_dtype = signature[f"{data_arg}_dtype"] + weight_dtype = signature[f"{weight_arg}_dtype"] + out_dtype = signature["ret_dtype"] + padding = op_attrs["padding"] + strides = op_attrs["strides"] + dilation = op_attrs["dilation"] + conv_kind = ConvKind.Fprop + + use_3xtf32 = self.options.get("use_3xtf32", False) + profile_all_alignments = self.options.get("profile_all_alignments", False) + find_first_valid = self.options.get("find_first_valid", True) + use_multiprocessing = self.options.get("use_multiprocessing", True) + split_k_slices = self.options.get("split_k_slices", [1]) + + op_name, op_def, _ = self.conv2d_profiler.profile( + op_type, + d_shape, + w_shape, + padding, + strides, + dilation, + out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, + conv_kind, + split_k_slices, + profile_all_alignments, + find_first_valid=find_first_valid, + use_multiprocessing=use_multiprocessing, + ) + + return f.with_attrs( + { + "op_type": op_type, + "data_arg_idx": arg_idx["lhs"], + "weight_arg_idx": arg_idx["rhs"], + "bias_arg_idx": arg_idx.get("bias"), + "residual_arg_idx": arg_idx.get("residual"), + "arg0_dtype": data_dtype, + "arg1_dtype": weight_dtype, + "ret_dtype": out_dtype, + "arg0_shape": d_shape, + "arg1_shape": w_shape, + "ret_shape": out_shape, + "strides": strides, + "padding": padding, + "dilation": dilation, + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + } + ) + + def handle_matmul(self, f, op_type): + """Tune and annotate a matmul op.""" + signature =
_extract_relax_function_signature(f) + arg_idx = _extract_arg_idx(op_type, f) + + lhs_arg = f"arg{arg_idx['lhs']}" + rhs_arg = f"arg{arg_idx['rhs']}" + + lhs_shape = signature[f"{lhs_arg}_shape"] + rhs_shape = signature[f"{rhs_arg}_shape"] + out_shape = signature["ret_shape"] + lhs_dtype = signature[f"{lhs_arg}_dtype"] + rhs_dtype = signature[f"{rhs_arg}_dtype"] + out_dtype = signature["ret_dtype"] + + if not is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape): + raise ValueError(f"Cannot handle the input shapes, lhs: {lhs_shape}, rhs: {rhs_shape}") + + MM = lhs_shape[-2] + KK = lhs_shape[-1] + if "transposed" in op_type: + NN = rhs_shape[-2] + ldb = "K" + layout_b = LayoutType.ColumnMajor + else: + NN = rhs_shape[-1] + ldb = "N" + layout_b = LayoutType.RowMajor + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + if lhs_batches == 1 and rhs_batches == 1: + # Regular matmul + is_batched = False + batch_attrs = {} + else: + is_batched = True + batch_attrs = { + # If both lhs_batches and rhs_batches are greater than 1, + # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul. + "batch": lhs_batches if rhs_batches == 1 else rhs_batches, + "batch_stride_A": 0 if lhs_batches == 1 else MM * KK, + "batch_stride_B": 0 if rhs_batches == 1 else KK * NN, + "batch_stride_C": MM * NN, + } + + use_3xtf32 = self.options.get("use_3xtf32", False) + find_first_valid = self.options.get("find_first_valid", True) + use_multiprocessing = self.options.get("use_multiprocessing", True) + + op_name, op_def, _ = self.gemm_profiler.profile( + op_type, + MM, + NN, + KK, + out_dtype, + lhs_dtype, + rhs_dtype, + use_3xtf32, + batched=is_batched, + find_first_valid=find_first_valid, + use_multiprocessing=use_multiprocessing, + layout_b=layout_b, + ) + + return f.with_attrs( + { + "op_type": op_type, + "lhs_arg_idx": arg_idx["lhs"], + "rhs_arg_idx": arg_idx["rhs"], + "residual_arg_idx": arg_idx.get("residual"), + "bias_arg_idx": arg_idx.get("bias"), + "arg0_dtype": signature["arg0_dtype"], + "arg1_dtype": signature["arg1_dtype"], + "ret_dtype": out_dtype, + "arg0_shape": signature["arg0_shape"], + "arg1_shape": signature["arg1_shape"], + "ret_shape": out_shape, + "lda": "K", + "ldb": ldb, + "ldc": "N", + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + **batch_attrs, + } + ) + + def handle_attention(self, f, op_type): + """Tune and annotate an attention op.""" + signature = _extract_relax_function_signature(f) + if _get_call_node(f.body, "relax.nn.attention") is not None: + op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs + elif _get_call_node(f.body, "relax.nn.attention_bias") is not None: + op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs + else: + raise ValueError("Cannot find call node for attention") + q_shape = signature["arg0_shape"] + k_shape = signature["arg1_shape"] + v_shape = signature["arg2_shape"] + out_shape = signature["ret_shape"] + q_dtype = signature["arg0_dtype"] + k_dtype = signature["arg1_dtype"] + v_dtype = signature["arg2_dtype"] + out_dtype = signature["ret_dtype"] + num_batches, num_queries, num_heads, head_dim = q_shape + _, num_keys, _, _ = k_shape + _, _, _, head_dim_value = v_shape + scale = op_attrs.scale + bias = {} + if "arg3_dtype" in signature: + bias["arg3_dtype"] = signature["arg3_dtype"] + if "arg3_shape" in signature: + bias["arg3_shape"] = signature["arg3_shape"] + + return f.with_attrs( + { + "op_type": op_type, + "arg0_dtype": q_dtype, + "arg1_dtype": k_dtype,
+ "arg2_dtype": v_dtype, + "ret_dtype": out_dtype, + "arg0_shape": q_shape, + "arg1_shape": k_shape, + "arg2_shape": v_shape, + "ret_shape": out_shape, + "num_batches": num_batches, + "num_queries": num_queries, + "num_keys": num_keys, + "num_heads": num_heads, + "head_dim": head_dim, + "head_dim_value": head_dim_value, + "scale": scale, + "arch": self.options["sm"], + **bias, + } + ) + + def visit_function_(self, f): + if "Composite" not in f.attrs: + body = super().visit_expr(f.body) + return relax.Function(f.params, body, f.ret_struct_info, f.attrs, f.span) + + op_type = f.attrs["Composite"] + + if "conv2d" in op_type: + return self.handle_conv2d(f, op_type) + elif "matmul" in op_type: + return self.handle_matmul(f, op_type) + elif "attention" in op_type: + return self.handle_attention(f, op_type) + + raise ValueError("Unsupported composite {}".format(op_type)) + + def visit_span(self, span): + return span + + +@register_func("contrib.cutlass.tune_relax_function") +def profile_relax_function(functions, options): + """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" + tmp_dir = options.get("tmp_dir", "./tmp") + sm = options.get("sm", 80) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + + annotated_functions = [] + + for f in functions: + annotator = CutlassRelaxFunctionAnnotator( + tvm.IRModule.from_expr(f), conv2d_profiler, gemm_profiler, options + ) + annotated_functions.append(annotator.visit_expr(f)) + + return annotated_functions + + @register_func("contrib.cutlass.compile") def compile_cutlass_module(c_source_module, options): """Compile all CUTLASS kernels in the given C-source module. diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 1444009799fe..f2d2f0127626 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -354,7 +354,7 @@ def emit( return substitute_template(template, values) -def instantiate_conv2d_template(attrs, func_args): +def instantiate_conv2d_template(attrs): """Return CUTLASS host code for conv2d based on a template and the provided attribute map.""" template = """ ${cutlass_op_def} @@ -382,8 +382,8 @@ def instantiate_conv2d_template(attrs, func_args): cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices); const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::${split_k_mode}; - void* ptr_a = (void*)(${arg0}->data); - void* ptr_b = (void*)(${arg1}->data); + void* ptr_a = (void*)(${data_arg}->data); + void* ptr_b = (void*)(${weight_arg}->data); ${bias_decl} ${residual_decl} void* ptr_out = (void*)(out0->data); @@ -466,7 +466,7 @@ def instantiate_conv2d_template(attrs, func_args): use_split_k = "splitk" in attrs["cutlass_op_name"] is_wgrad = "backward_weight" in op_type is_dgrad = "conv2d_transpose" in op_type - has_residual_blcok = "residual" in op_type + has_residual_block = "residual" in op_type no_bias_scaling = op_type not in [ "cutlass.conv2d_bias_sigmoid", "cutlass.conv2d_bias_silu", @@ -475,18 +475,18 @@ def instantiate_conv2d_template(attrs, func_args): aux_map = {} - if (not has_bias or no_bias_scaling) and not has_residual_blcok: - aux_map["beta"] = "0" + if (not has_bias or no_bias_scaling) and not has_residual_block: + aux_map["beta"] = 
0 else: - aux_map["beta"] = "1" + aux_map["beta"] = 1 - if has_residual_blcok: - aux_map["bias_decl"] = "void* ptr_bias = (void*)(${arg2}->data);\n" - aux_map["residual_decl"] = "void* ptr_residual = (void*)(${arg3}->data);" + if has_residual_block: + aux_map["bias_decl"] = "void* ptr_bias = (void*)(${bias_arg}->data);\n" + aux_map["residual_decl"] = "void* ptr_residual = (void*)(${residual_arg}->data);" aux_map["tensor_c"] = "ptr_residual" aux_map["tensor_c_layout"] = "layout_C" elif has_bias: - aux_map["bias_decl"] = "void* ptr_c_bias = (void*)(${arg2}->data);\n" + aux_map["bias_decl"] = "void* ptr_c_bias = (void*)(${bias_arg}->data);\n" aux_map["residual_decl"] = "" aux_map["tensor_c"] = "ptr_c_bias" aux_map["tensor_c_layout"] = "cutlass::layout::TensorNHWC::Stride(0)" @@ -496,12 +496,12 @@ def instantiate_conv2d_template(attrs, func_args): aux_map["tensor_c"] = "ptr_out" aux_map["tensor_c_layout"] = "layout_C" - if has_bias and no_bias_scaling and not has_residual_blcok: + if has_bias and no_bias_scaling and not has_residual_block: aux_map["alpha_beta"] = "alpha" else: aux_map["alpha_beta"] = "alpha, beta" - if has_residual_blcok: + if has_residual_block: aux_map["additional_args"] = ", static_cast(ptr_bias), nullptr, 0, K" else: aux_map["additional_args"] = "" @@ -534,7 +534,4 @@ def instantiate_conv2d_template(attrs, func_args): template = substitute_template(template, aux_map) - for i, arg in enumerate(func_args): - attrs["arg{}".format(i)] = arg - return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 58f5de6a9c9a..b820ead016fe 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -164,6 +164,7 @@ def __init__(self): ${element_accumulator}, ${element_epilogue} >""" + self.epilogue_no_beta_scaling = """ ${epilogue_functor}< ${element_c}, @@ -172,6 +173,19 @@ def __init__(self): ${element_epilogue}, cutlass::epilogue::thread::ScaleType::NoBetaScaling >""" + + self.epilogue_residual_block = """ + ${epilogue_functor}< + ${element_c}, + ${element_accumulator}, + ${element_epilogue}, + ${element_c}, + ${epilogue_vector_length}, + ${activation}, + ${binary_op}, + ${unary_op} + >""" + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}< @@ -188,13 +202,11 @@ def __init__(self): ${swizzling_functor}, ${stages}, ${align_a}, - ${align_b}, - ${split_k_serial} - ${math_operation} + ${align_b} >; """ - def emit(self, operation, no_beta_scaling=False, batched=False): + def emit(self, operation, no_beta_scaling=False, batched=False, residual_block_info=False): """Instantiate a GEMM kernel from given `operation`.""" warp_shape = [ operation.tile_description.threadblock_shape[idx] @@ -246,22 +258,73 @@ def emit(self, operation, no_beta_scaling=False, batched=False): } values["kernel_name"] = "GemmBatched" if batched else "Gemm" - values["split_k_serial"] = "" if batched else "false," - gemm_template = substitute_template( - self.gemm_template, - { - "epilogue": self.epilogue_no_beta_scaling - if no_beta_scaling - else self.epilogue_default - }, - ) - return substitute_template(gemm_template, values) + if residual_block_info: + values["kernel_name"] = "GemmUniversalWithBroadcast" + template = substitute_template( + self.gemm_template, {"epilogue": self.epilogue_residual_block} + ) + values.update( + { + "unary_op": residual_block_info["unary_op"], + "binary_op": 
residual_block_info["binary_op"], + "activation": residual_block_info["activation"], + } + ) + elif no_beta_scaling: + template = substitute_template( + self.gemm_template, {"epilogue": self.epilogue_no_beta_scaling} + ) + else: + template = substitute_template(self.gemm_template, {"epilogue": self.epilogue_default}) + + return substitute_template(template, values) -def instantiate_gemm_template(attrs, func_args): +def instantiate_gemm_template(attrs): """Return CUTLASS host code for GEMM based on a template and the provided attribute map.""" + argument_template_default = """ + typename ${kernel}::Arguments arguments{ + problem_size, + {static_cast(ptr_a), ${lda}}, ${batch_stride_A} + {static_cast(ptr_b), ${ldb}}, ${batch_stride_B} + {static_cast(${ptr_c}), ${c_stride}}, ${batch_stride_C} + {static_cast(ptr_out), ${ldc}}, ${batch_stride_D} + {${alpha_beta}}, + ${split_k_slices_or_batch} + }; + """ + + # See cutlass/gemm/kernel/gemm_with_fused_epilogue.h + # Batched GEMM + residual fusion is not supported for now. + argument_template_residual = """ + typename ${kernel}::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + 1, // batch_count, + {${alpha_beta}}, + static_cast(ptr_a), + static_cast(ptr_b), + static_cast(ptr_residual), + static_cast(ptr_out), + static_cast(ptr_bias), + nullptr, // ptr_Tensor + 0, // batch_stride_A, + 0, // batch_stride_B, + 0, // batch_stride_C, + 0, // batch_stride_D, + 0, // batch_stride_Vector, + 0, // batch_stride_Tensor, + ${lda}, + ${ldb}, + ${ldc}, + ${ldc}, + 0, // ldv, the stride for bias + 0, // ldt + }; + """ + template = """ using ElementInputA = ${ElementInputA}; using ElementInputB = ${ElementInputB}; @@ -277,20 +340,13 @@ def instantiate_gemm_template(attrs, func_args): cutlass::gemm::GemmCoord problem_size(M, N, K); ElementComputeEpilogue alpha = ElementComputeEpilogue(1); ElementComputeEpilogue beta = ElementComputeEpilogue(${beta}); - void* ptr_a = (void*)(${arg0}->data); - void* ptr_b = (void*)(${arg1}->data); + void* ptr_a = (void*)(${lhs_arg}->data); + void* ptr_b = (void*)(${rhs_arg}->data); ${bias_decl} + ${residual_decl} void* ptr_out = (void*)(out0->data); - typename ${kernel}::Arguments arguments{ - problem_size, - {static_cast(ptr_a), ${lda}}, ${batch_stride_A} - {static_cast(ptr_b), ${ldb}}, ${batch_stride_B} - {static_cast(${ptr_c}), ${c_stride}}, ${batch_stride_C} - {static_cast(ptr_out), ${ldc}}, ${batch_stride_C} - {${alpha_beta}}, - ${split_k_slices_or_batch} - }; + ${argument} size_t workspace_size = ${kernel}::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); ${kernel} gemm_op; @@ -301,30 +357,32 @@ def instantiate_gemm_template(attrs, func_args): status = gemm_op(); CHECK(status == cutlass::Status::kSuccess); """ - has_bias = "bias" in attrs["op_type"] - is_gelu = "gelu" in attrs["op_type"] - batched = "batch_matmul" in attrs["op_type"] - + op_type = attrs["op_type"] + has_bias = "bias" in op_type + is_gelu = "gelu" in op_type + batched = "batch" in attrs + has_residual_block = "residual" in op_type aux_map = {"kernel": "Gemm"} if has_bias: aux_map.update( { - "bias_decl": "void* ptr_c_bias = (void*)(${arg2}->data);\n", - "ptr_c": "ptr_c_bias", - "c_stride": "0", + "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n", + "ptr_c": "ptr_bias", + "c_stride": "(${bias_arg}->ndim == 1 || ${bias_arg}->shape[0] == 1) ? 
0 : " + + attrs["ldc"], } ) else: aux_map.update({"bias_decl": "", "ptr_c": "ptr_out", "c_stride": attrs["ldc"]}) - if is_gelu: + if is_gelu or has_residual_block: # GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. - aux_map["beta"] = "1" + aux_map["beta"] = 1 else: - aux_map["beta"] = "0" + aux_map["beta"] = 0 - if has_bias and not is_gelu: + if has_bias and not is_gelu and not has_residual_block: aux_map["alpha_beta"] = "alpha" else: aux_map["alpha_beta"] = "alpha, beta" @@ -335,14 +393,23 @@ def instantiate_gemm_template(attrs, func_args): else: aux_map[key] = attrs[key] + "," + aux_map["batch_stride_D"] = aux_map["batch_stride_C"] + if has_bias and batched: + aux_map["batch_stride_C"] = "0," + if batched: attrs["split_k_slices_or_batch"] = attrs["batch"] else: - attrs["split_k_slices_or_batch"] = "1" + attrs["split_k_slices_or_batch"] = 1 - template = substitute_template(template, aux_map) + if has_residual_block: + assert not batched, "Residual fusion is supported only for non-batched GEMM for now." + template = substitute_template(template, {"argument": argument_template_residual}) + aux_map["residual_decl"] = "void* ptr_residual = (void*)(${residual_arg}->data);\n" + else: + template = substitute_template(template, {"argument": argument_template_default}) + aux_map["residual_decl"] = "" - for i, arg in enumerate(func_args): - attrs["arg{}".format(i)] = arg + template = substitute_template(template, aux_map) return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/gemm_profiler.py b/python/tvm/contrib/cutlass/gemm_profiler.py index 13679cd05c42..e89e7defbfb7 100644 --- a/python/tvm/contrib/cutlass/gemm_profiler.py +++ b/python/tvm/contrib/cutlass/gemm_profiler.py @@ -55,7 +55,7 @@ def __init__(self): } template -cudaError_t CutlassGemmRCR( +cudaError_t CutlassGemm( int M, int N, int K, @@ -148,7 +148,7 @@ def __init__(self): cudaFree(B); return result; } - result = CutlassGemmRCR(M, N, K, alpha, A, lda, B, ldb, + result = CutlassGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc); if (result != cudaSuccess) { std::cerr << "CUTLASS GEMM kernel failed: " diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index bb26a47a5548..9e9e16426ba6 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -16,6 +16,8 @@ # under the License. 
# pylint: disable=invalid-name, dangerous-default-value """Conv2d kernel generator and profiler for CUTLASS.""" +import os +import pickle from functools import partial from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler @@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue( tile_description, data_type, alignment, + alignment_epilogue, swizzling_functor, split_k_slices, ): @@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue( A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) - C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue) op = Conv2dOperation( conv_kind, @@ -110,6 +113,7 @@ def enumerate_conv2d_operators( conv_kind, stride_support, split_k_slices, + alignment_c, tile_descriptions, data_type, alignment_constraints, @@ -128,47 +132,49 @@ def enumerate_conv2d_operators( for split_k_slice in split_k_slices: for tile in tile_descriptions: - for alignment in alignment_constraints: - - A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) - B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) - C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) - - if element_c == DataType.s32 and A.alignment == 1: - tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128) - tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128) - - op = Conv2dOperation( - conv_kind, - IteratorAlgorithm.Optimized, - tile.minimum_compute_capability, - tile, - A, - B, - C, - element_epilogue, - stride_support, - EpilogueFunctor.LinearCombination, - swizzling_functor, - split_k_slice, - ) - - ret.append( - { - "src": profiler_emitter.emit( - kernel_emitter.emit(op, emit_reduction=split_k_slice > 1), - op.procedural_name(), - element_output=element_c, - split_k_slices=split_k_slice, - ), - "name": op.procedural_name(), - "tile_description": tile, - "alignment": alignment, - "data_type": data_type, - "swizzle_functor": swizzling_functor, - "split_k_slices": split_k_slice, - } - ) + for alignmentAB in alignment_constraints: + for alignmentC in alignment_c: + + A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB) + B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC) + + if element_c == DataType.s32 and A.alignment == 1: + tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128) + tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128) + + op = Conv2dOperation( + conv_kind, + IteratorAlgorithm.Optimized, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + stride_support, + EpilogueFunctor.LinearCombination, + swizzling_functor, + split_k_slice, + ) + + ret.append( + { + "src": profiler_emitter.emit( + kernel_emitter.emit(op, emit_reduction=split_k_slice > 1), + op.procedural_name(), + element_output=element_c, + split_k_slices=split_k_slice, + ), + "name": op.procedural_name(), + "tile_description": tile, + "alignment": alignmentAB, + "alignment_epilogue": alignmentC, + "data_type": data_type, + "swizzle_functor": swizzling_functor, + "split_k_slices": split_k_slice, + } + ) return ret @@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path): self.sm = sm assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." 
% sm self.engine = ProfilerEngine(sm, cutlass_path, binary_path) - self.cache = {} + self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle") + if os.path.exists(self.cache_path): + self.cache = pickle.load(open(self.cache_path, "rb")) + else: + self.cache = {} def get_default( self, @@ -216,6 +226,7 @@ def get_default( tile_description, data_type, alignment, + alignment, swizzling_functor, split_k_slices=1, ) @@ -265,12 +276,27 @@ def select_op( if workload in self.cache: return self.cache[workload] + def alignments(dtype): + if dtype in ["float16"]: + alignments = [8, 4, 2, 1] + elif dtype in ["float", "float32"]: + alignments = [4, 2, 1] + else: + raise ValueError("Unsupported data type: %s" % dtype) + return alignments + ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, data_dtype, weight_dtype, - partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices), - lambda align: all([dim % align == 0 for dim in [IC, OC]]), + partial( + enumerate_conv2d_operators, + conv_kind, + stride_support, + split_k_slices, + [align for align in alignments(out_dtype) if OC % align == 0], + ), + lambda align: all([dim % align == 0 for dim in [IC]]), use_3xtf32, profile_all_alignments, # Use fp32 accumulation for wgrad to align with cuDNN @@ -294,6 +320,8 @@ def select_op( op = min(ops, key=lambda i: i["runtime"]) self.cache[workload] = op + with open(self.cache_path, "wb") as f: + pickle.dump(self.cache, f) return op def profile( @@ -350,6 +378,7 @@ def profile( op["tile_description"], op["data_type"], op["alignment"], + op["alignment_epilogue"], op["swizzle_functor"], op["split_k_slices"], ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index ddeddbd39cac..0ea6231b8196 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,6 +16,10 @@ # under the License. 
# pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" +import os +import pickle +from functools import partial + from .gemm_operation import EmitGemmInstance, GemmOperation from .gemm_profiler import GemmProfilerEmitter from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine @@ -36,6 +40,7 @@ def create_gemm_operator_with_epilogue( alignment, swizzling_functor, batched=False, + layout_b=LayoutType.ColumnMajor, ): """ Instantiate a cutlass kernel from the given configuration, @@ -44,13 +49,42 @@ def create_gemm_operator_with_epilogue( element_a, element_b, element_c, element_epilogue = data_type A = TensorDescription(element_a, LayoutType.RowMajor, alignment) - B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + B = TensorDescription(element_b, layout_b, alignment) C = TensorDescription(element_c, LayoutType.RowMajor, alignment) if batched: swizzling_functor = SwizzlingFunctor.Batched - epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + if "residual" in op_type: + if "hardswish" in op_type: + activation = "cutlass::epilogue::thread::HardSwish" + elif "silu" in op_type: + activation = "cutlass::epilogue::thread::SiLu" + elif "sigmoid" in op_type: + activation = "cutlass::epilogue::thread::Sigmoid" + elif "gelu" in op_type: + activation = "cutlass::epilogue::thread::GELU" + elif "relu" in op_type: + activation = "cutlass::epilogue::thread::ReLu" + else: + activation = "cutlass::epilogue::thread::Identity" + + binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" + unary_op = ( + "cutlass::epilogue::thread::ReLu" + if op_type.endswith("relu") + else "cutlass::epilogue::thread::Identity" + ) + residual_block_info = { + "activation": activation, + "binary_op": binary_op, + "unary_op": unary_op, + } + epilogue = EpilogueFunctor.LinearCombinationResidualBlock + no_beta_scaling = False + else: + residual_block_info = None + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] op = GemmOperation( tile_description.minimum_compute_capability, @@ -65,7 +99,12 @@ def create_gemm_operator_with_epilogue( return ( op.procedural_name(), - EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched), + EmitGemmInstance().emit( + op, + no_beta_scaling=no_beta_scaling, + batched=batched, + residual_block_info=residual_block_info, + ), ) @@ -74,6 +113,7 @@ def enumerate_gemm_operators( data_type, alignment_constraints, swizzling_functor=SwizzlingFunctor.Identity8, + layout_b=LayoutType.ColumnMajor, ): """Exhaustively instantiate all kernels from a given configuration.""" ret = [] @@ -85,7 +125,7 @@ def enumerate_gemm_operators( for tile_description in tile_descriptions: for alignment in alignment_constraints: A = TensorDescription(element_a, LayoutType.RowMajor, alignment) - B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + B = TensorDescription(element_b, layout_b, alignment) C = TensorDescription(element_c, LayoutType.RowMajor, alignment) if element_c == DataType.s32 and A.alignment == 1: @@ -157,10 +197,21 @@ def __init__(self, sm, cutlass_path, binary_path): assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." 
% sm self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.sm = sm - self.cache = {} + self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle") + if os.path.exists(self.cache_path): + self.cache = pickle.load(open(self.cache_path, "rb")) + else: + self.cache = {} def get_default( - self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False + self, + op_type, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32=True, + batched=False, + layout_b=LayoutType.ColumnMajor, ): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. @@ -169,7 +220,7 @@ def get_default( out_dtype, arg0_dtype, arg1_dtype, - enumerate_gemm_operators, + partial(enumerate_gemm_operators, layout_b=layout_b), lambda align: align == 1, # Only request align1 kernels use_3xtf32, profile_all_alignments=True, # To include all align1 kernels @@ -194,6 +245,7 @@ def get_default( op["alignment"], op["swizzle_functor"], batched=batched, + layout_b=layout_b, ) op.update({"name": name, "opdef": opdef}) return op @@ -210,6 +262,7 @@ def select_op( profile_all_alignments=False, find_first_valid=False, use_multiprocessing=False, + layout_b=LayoutType.ColumnMajor, ): """ Profile and select the best kernel from candidate kernels. @@ -227,7 +280,7 @@ def select_op( out_dtype, arg0_dtype, arg1_dtype, - enumerate_gemm_operators, + partial(enumerate_gemm_operators, layout_b=layout_b), lambda align: all([dim % align == 0 for dim in [M, N, K]]), use_3xtf32, profile_all_alignments=profile_all_alignments, @@ -247,6 +300,8 @@ def select_op( op = min(ops, key=lambda i: i["runtime"]) self.cache[(M, N, K)] = op + with open(self.cache_path, "wb") as f: + pickle.dump(self.cache, f) return op def profile( @@ -263,6 +318,7 @@ def profile( find_first_valid=False, use_multiprocessing=False, batched=False, + layout_b=LayoutType.ColumnMajor, ): """Profile and select the best kernel from candidate kernels. If find_first_valid is True, return immediately after the first applicable kernel is found. @@ -279,6 +335,7 @@ def profile( profile_all_alignments=profile_all_alignments, find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, + layout_b=layout_b, ) name, opdef = create_gemm_operator_with_epilogue( @@ -288,6 +345,7 @@ def profile( op["alignment"], op["swizzle_functor"], batched=batched, + layout_b=layout_b, ) return name, opdef, op["runtime"] diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 1eeb0f4b26b6..61c88c657f05 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -17,27 +17,31 @@ # pylint: disable=invalid-name """Common functions and classes for CUTLASS GEMM and Conv2d geneator.""" import logging +import math +import multiprocessing import os import re -import tempfile import subprocess -import multiprocessing +import tempfile + import tvm._ffi -from tvm.tir import IntImm from tvm.runtime import Object +from tvm.tir import IntImm + from . 
import _ffi_api as ffi +from .attention_operation import instantiate_attention_template +from .conv2d_operation import instantiate_conv2d_template +from .gemm_operation import instantiate_gemm_template from .library import ( - MathInstruction, DataType, + DataTypeSize, DataTypeTag, - OpcodeClass, + EpilogueFunctor, + MathInstruction, MathOperation, + OpcodeClass, TileDescription, - EpilogueFunctor, ) -from .gemm_operation import instantiate_gemm_template -from .conv2d_operation import instantiate_conv2d_template - logger = logging.getLogger("cutlass") @@ -91,6 +95,7 @@ def generate_sm50_simt( DataType.f32, DataType.f32, DataType.f32, + DataType.f32, OpcodeClass.Simt, MathOperation.multiply_add, ), @@ -227,8 +232,9 @@ def generate_sm80_tensor_op_16816( def get_default_tile_descriptions(block_k_factor): return [ - ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), + ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc), ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), @@ -242,6 +248,9 @@ def get_default_tile_descriptions(block_k_factor): ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited), ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited), ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc), + ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc), + ([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), @@ -367,6 +376,14 @@ def get_tile_descriptions(math_inst): "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.matmul": (EpilogueFunctor.LinearCombination, False), + "cutlass.matmul_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.matmul_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.matmul_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.matmul_transposed": (EpilogueFunctor.LinearCombination, False), + "cutlass.matmul_transposed_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.matmul_transposed_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.matmul_transposed_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False), "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False), "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False), @@ -450,6 +467,13 @@ def __init__(self, code, headers): self.__init_handle_by_constructor__(ffi.CodegenResult, code, headers) +def _get_optional_int_annotation(annotations, key, default=None): + value = annotations.get(key, default) + if value is None: + return default + return int(value) + + @tvm._ffi.register_func("contrib.cutlass.instantiate_template") def 
instantiate_template(func_name, annotations, func_args): """Return CUTLASS host code based on a template and the provided annotations. @@ -513,41 +537,114 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ return dim1 + " * " + dim2 if "dense" in func_name or "matmul" in func_name: - batched = "batch_matmul" in func_name - batched_offset = 1 if batched else 0 - attrs["K"] = str(int(arg0_shape[batched_offset + 1])) - attrs["M"] = get_dim(arg0_shape[batched_offset], func_args[0], 0, batched_offset) - - if annotations["ldb"] == "N": - attrs["N"] = get_dim(arg1_shape[batched_offset + 1], func_args[1], 1, batched_offset) + batched = "batch" in annotations + transposed = "transposed" in func_name + lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0) + rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1) + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) + residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None) + + lhs_arg = func_args[lhs_arg_idx] + rhs_arg = func_args[rhs_arg_idx] + lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"] + rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"] + lhs_batched_offset = len(lhs_shape) - 2 + rhs_batched_offset = len(rhs_shape) - 2 + + attrs["lhs_arg"] = lhs_arg + attrs["rhs_arg"] = rhs_arg + + if bias_arg_idx is not None: + attrs["bias_arg"] = func_args[bias_arg_idx] + if residual_arg_idx is not None: + attrs["residual_arg"] = func_args[residual_arg_idx] + + attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]] + attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]] + attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]] + + attrs["K"] = lhs_shape[lhs_batched_offset + 1] + attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset) + + if transposed: + attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0, rhs_batched_offset) else: - attrs["N"] = get_dim(arg1_shape[batched_offset], func_args[1], 0, batched_offset) + attrs["N"] = get_dim(rhs_shape[rhs_batched_offset + 1], rhs_arg, 1, rhs_batched_offset) if batched: headers.append("cutlass/gemm/device/gemm_batched.h") - attrs["batch"] = get_dim(arg0_shape[0], func_args[0], 0) - attrs["batch_stride_A"] = get_batch_stride(annotations["batch_stride_A"], 0, 0, 1, 2) - attrs["batch_stride_B"] = get_batch_stride(annotations["batch_stride_B"], 1, 1, 1, 2) - if annotations["ldb"] == "N": + def get_batch_on_arg(arg_name, arg_shape): + return " * ".join( + "{}->shape[{}]".format(arg_name, i) for i in range(len(arg_shape) - 2) + ) + + if isinstance(annotations["batch"], IntImm): + attrs["batch"] = str(int(annotations["batch"])) + elif annotations["batch_stride_A"] == 0: + # 2D x ND + attrs["batch"] = get_batch_on_arg(rhs_arg, rhs_shape) + else: + # ND x 2D or ND x ND + attrs["batch"] = get_batch_on_arg(lhs_arg, lhs_shape) + + attrs["batch_stride_A"] = get_batch_stride( + annotations["batch_stride_A"], + lhs_arg_idx, + lhs_arg_idx, + lhs_batched_offset, + lhs_batched_offset + 1, + ) + attrs["batch_stride_B"] = get_batch_stride( + annotations["batch_stride_B"], + rhs_arg_idx, + rhs_arg_idx, + rhs_batched_offset, + rhs_batched_offset + 1, + ) + + if transposed: attrs["batch_stride_C"] = get_batch_stride( - annotations["batch_stride_C"], 0, 1, 1, 2 + annotations["batch_stride_C"], + lhs_arg_idx, + rhs_arg_idx, + lhs_batched_offset, + rhs_batched_offset, ) else: 
attrs["batch_stride_C"] = get_batch_stride( - annotations["batch_stride_C"], 0, 1, 1, 1 + annotations["batch_stride_C"], + lhs_arg_idx, + rhs_arg_idx, + lhs_batched_offset, + rhs_batched_offset + 1, ) else: headers.append("cutlass/gemm/device/gemm.h") - code = instantiate_gemm_template(attrs, func_args) + if "residual" in func_name: + headers.append("cutlass/gemm/device/gemm_universal_with_broadcast.h") + + code = instantiate_gemm_template(attrs) return CodegenResult(code, headers) elif "conv2d" in func_name: - activation_shape = arg0_shape - weight_shape = arg1_shape + data_arg_idx = _get_optional_int_annotation(annotations, "data_arg_idx", 0) + weight_arg_idx = _get_optional_int_annotation(annotations, "weight_arg_idx", 1) + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None) + residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None) + + attrs["data_arg"] = func_args[data_arg_idx] + attrs["weight_arg"] = func_args[weight_arg_idx] + + if bias_arg_idx is not None: + attrs["bias_arg"] = func_args[bias_arg_idx] + if residual_arg_idx is not None: + attrs["residual_arg"] = func_args[residual_arg_idx] + + activation_shape = annotations[f"arg{data_arg_idx}_shape"] + weight_shape = annotations[f"arg{weight_arg_idx}_shape"] output_shape = annotations["ret_shape"] - activation_var = func_args[0] if "conv2d_transpose" in func_name: headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h") @@ -573,30 +670,82 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ "cutlass/reduction/thread/reduction_operators.h", ] - attrs["N"] = get_dim(activation_shape[0], activation_var, 0) - attrs["H"] = get_dim(activation_shape[1], activation_var, 1) - attrs["W"] = get_dim(activation_shape[2], activation_var, 2) - attrs["C"] = str(int(activation_shape[3])) + data_arg = attrs["data_arg"] + attrs["N"] = get_dim(activation_shape[0], data_arg, 0) + attrs["H"] = get_dim(activation_shape[1], data_arg, 1) + attrs["W"] = get_dim(activation_shape[2], data_arg, 2) + attrs["C"] = activation_shape[3] attrs["P"] = get_dim(output_shape[1], "out0", 1) attrs["Q"] = get_dim(output_shape[2], "out0", 2) - attrs["K"] = str(int(output_shape[3])) - attrs["R"] = str(int(weight_shape[1])) - attrs["S"] = str(int(weight_shape[2])) - attrs["pad_h"] = str(int(annotations["padding"][0])) - attrs["pad_w"] = str(int(annotations["padding"][1])) - attrs["stride_h"] = str(int(annotations["strides"][0])) - attrs["stride_w"] = str(int(annotations["strides"][1])) - attrs["dilation_h"] = str(int(annotations["dilation"][0])) - attrs["dilation_w"] = str(int(annotations["dilation"][1])) + attrs["K"] = output_shape[3] + attrs["R"] = weight_shape[1] + attrs["S"] = weight_shape[2] + attrs["pad_h"] = annotations["padding"][0] + attrs["pad_w"] = annotations["padding"][1] + attrs["stride_h"] = annotations["strides"][0] + attrs["stride_w"] = annotations["strides"][1] + attrs["dilation_h"] = annotations["dilation"][0] + attrs["dilation_w"] = annotations["dilation"][1] if "splitk" in op_name: attrs["split_k_mode"] = "kParallel" attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", op_name).group(1)) else: attrs["split_k_mode"] = "kSerial" - attrs["split_k_slices"] = "1" + attrs["split_k_slices"] = 1 + + code = instantiate_conv2d_template(attrs) + return CodegenResult(code, headers) - code = instantiate_conv2d_template(attrs, func_args) + elif "attention" in func_name: + headers.append("kernel_forward.h") + data_type = dtype_map[annotations["arg0_dtype"]] + 
attrs["data_type"] = DataTypeTag[data_type] + attrs["num_batches"] = b = annotations["num_batches"] + attrs["num_queries"] = s = annotations["num_queries"] + attrs["num_keys"] = annotations["num_keys"] + attrs["num_heads"] = n = annotations["num_heads"] + attrs["head_dim"] = h = annotations["head_dim"] + attrs["head_dim_value"] = h_v = annotations["head_dim_value"] + data_type_size = DataTypeSize[data_type] + if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0: + attrs["kIsAligned"] = True + elif (h % 4 == 0) and (h_v % 4 == 0): + attrs["kIsAligned"] = False + else: + raise NotImplementedError() + if h_v > 64: + attrs["kQueriesPerBlock"] = 32 + attrs["kKeysPerBlock"] = 128 + attrs["kSingleValueIteration"] = h_v <= 128 + else: + attrs["kQueriesPerBlock"] = 64 + attrs["kKeysPerBlock"] = 64 + attrs["kSingleValueIteration"] = True + attrs["output_size"] = b * s * n * h_v + attrs["scale"] = ( + float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"] + ) + assert ( + attrs["scale"] > 0 or attrs["scale"] < 0 + ), "Cutlass may generate nan occasionally when scale == 0.0" + attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) + attrs["kSupportsDropout"] = False + if len(func_args) > 3: + attrs["kSupportsBias"] = True + if len(annotations["arg3_shape"]) == 4: + attrs["bias_layout"] = "BNSS'" + elif len(annotations["arg3_shape"]) == 3: + attrs["bias_layout"] = "B1SS'" + elif len(annotations["arg3_shape"]) == 2: + attrs["bias_layout"] = "B11S'" + else: + raise NotImplementedError() + else: + # To support negative scale in current Cutlass implementation, + # kSupportsBias should be set true, or there are nan's as result. + attrs["kSupportsBias"] = attrs["scale"] < 0 + code = instantiate_attention_template(attrs, func_args) return CodegenResult(code, headers) raise ValueError("Do not have a template for {}".format(func_name)) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 8632ab15641d..ead5804b59a0 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -20,6 +20,8 @@ import enum from enum import auto as enum_auto +from tvm.tir.expr import IntImm, FloatImm + class GeneratorTarget(enum.Enum): Library = enum_auto() @@ -143,6 +145,12 @@ def substitute_template(template, values): while changed: changed = False for key, value in values.items(): + if isinstance(value, (int, IntImm)): + value = str(int(value)) + if isinstance(value, (float, FloatImm)): + value = str(float(value)) + elif isinstance(value, bool): + value = str(value).lower() regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 0fcbcb7c790d..e21a0a00e648 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -23,7 +23,9 @@ from typing import Union import tvm +from tvm import relax from tvm import rpc as _rpc +from tvm.contrib import utils import tvm.contrib.hexagon as hexagon from tvm.relay.backend.executor_factory import ( ExecutorFactoryModule, @@ -283,13 +285,13 @@ def get_graph_debug_executor( graph_json, graph_debug_mod, self.device, dump_root=str(dump_root) ) - def get_executor_from_factory(self, module: ExecutorFactoryModule): + def get_executor_from_factory(self, module: Union[ExecutorFactoryModule, relax.Executable]): """Create a local GraphModule which consumes a remote libmod. 
Parameters ---------- - module : ExecutorFactoryModule + module : Union[ExecutorFactoryModule, relax.Executable] The module to upload to the remote session and load. @@ -298,6 +300,8 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule): return self._aot_executor_from_factory(module) if isinstance(module, GraphExecutorFactoryModule): return self._graph_executor_from_factory(module) + if isinstance(module, relax.Executable): + return self._relax_vm_executable_executor(module) raise TypeError(f"Unsupported executor type: {type(module)}") @@ -349,6 +353,35 @@ def _graph_executor_from_factory( """ return self.get_graph_executor(module.get_graph_json(), module.get_lib()) + def _relax_vm_executable_executor(self, vm_exec: relax.Executable): + """Create a local TVM module which consumes a remote vm executable. + + Paramters + --------- + + vm_exec : relax.Executable + The Relax VM Executable to upload to the remote and load. This will typically be the + output of `relax.build`. + + Returns + ------- + TVMModule : + TVM module object + """ + assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + + vm_exec.mod.export_library( + path_exec, + fcompile=hexagon.create_aot_shared, + hexagon_arch="v68", + ) + + path = self.upload(path_exec, "exec.so") + return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) + def _aot_executor_from_factory( self, module: Union[str, pathlib.Path, AOTExecutorFactoryModule], diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py new file mode 100644 index 000000000000..3783baefe01b --- /dev/null +++ b/python/tvm/contrib/tvmjs.py @@ -0,0 +1,301 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace to store utilities for building web runtime.""" +# pylint: disable=unused-import +import sys +import os +import json +import shutil +from typing import Mapping, Union + +import numpy as np + +import tvm +from tvm._ffi.libinfo import find_lib_path +from .emcc import create_tvmjs_wasm + + +def _convert_f32_to_bf16(value): + cap = np.finfo("float32").max + assert -np.finfo("float32").max == np.finfo("float32").min + bf16_limit = ((np.array([cap.view("uint32")]) >> 16) << 16).view("float32")[0] + # When the value is in [-bf16_limit, bf16_limit], round to nearest even. + # We can afford to do it in dumping phase to reduce overall rounding error. 
+ # + # When the value is out of bound(usually mask values in attention), use truncation + # so it is equivalent to clip to the limit values + data = value.view("uint32") + rounding_bias = np.where( + np.logical_and(value < bf16_limit, value > -bf16_limit), + ((data >> 16) & 1) + 0x7FFF, + np.zeros_like(data), + ) + return ((data + rounding_bias) >> 16).astype("uint16") + + +def _convert_bf16_to_f32(value): + data = value.view("uint16") + return (data.astype("uint32") << 16).view("float32") + + +class NDArrayCacheShardingManager: + """Internal helper to shard ndarrays.""" + + def __init__(self, cache_dir: str, prefix: str, shard_cap_nbytes: int): + self.cache_dir = cache_dir + self.prefix = prefix + self.curr_records = [] + self.curr_data = bytearray() + self.shard_records = [] + self.shard_cap_nbytes = shard_cap_nbytes + self.counter = 0 + + def append(self, data, name, shape, dtype, encode_format): + """Commit a record to the manager. + + Parameters + ---------- + data: bytes + Raw bytes to be appended. + + name: str + The name of the parameter + + shape: tuple + The shape of the array + + dtype: str + The dtype information + + encode_format: + The encode format of the entry + """ + rec = { + "name": name, + "shape": shape, + "dtype": dtype, + "format": encode_format, + "nbytes": len(data), + } + + if self.pending_nbytes + len(data) >= self.shard_cap_nbytes: + if len(data) * 2 >= self.shard_cap_nbytes: + # out of band data + rec["byteOffset"] = 0 + self._commit_internal(data, [rec]) + return + self.commit() + rec["byteOffset"] = self.pending_nbytes + self.curr_records.append(rec) + self.curr_data += data + + def commit(self): + """Commit a record""" + if self.pending_nbytes != 0: + self._commit_internal(self.curr_data, self.curr_records) + self.curr_data = bytearray() + self.curr_records = [] + + def finish(self): + """Finish building and return shard records.""" + self.commit() + return self.shard_records + + def _commit_internal(self, data, records): + data_path = f"{self.prefix}_{self.counter}.bin" + self.counter += 1 + with open(os.path.join(self.cache_dir, data_path), "wb") as outfile: + outfile.write(data) + shard_record = { + "dataPath": data_path, + "format": "raw-shard", + "nbytes": len(data), + "records": records, + } + self.shard_records.append(shard_record) + + @property + def pending_nbytes(self): + """Return total bytes stored so far""" + return len(self.curr_data) + + +def dump_ndarray_cache( + params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], + cache_dir: str, + encode_format="f32-to-bf16", + meta_data=None, + shard_cap_mb=32, +): + """Dump parameters to NDArray cache. + + Parameters + ---------- + params: Mapping[str, tvm.runtime.NDArray], + The parameter dictionary + + cache_dir: str + The path to the cache + + encode_format: {"f32-to-bf16", "raw"} + Encoding format. + + meta_data: json-compatible-struct + Extra meta_data to be stored in the cache json file. 
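`NDArrayCacheShardingManager` packs raw parameter bytes into shard files no larger than `shard_cap_nbytes`, spilling any record that is at least half the cap into its own out-of-band shard. A rough illustration with a deliberately tiny cap (the cache path is a placeholder):

```python
# Illustrative only: an 8-byte cap forces the second record into its own shard.
import os
from tvm.contrib.tvmjs import NDArrayCacheShardingManager

cache_dir = "/tmp/ndarray_cache_demo"  # placeholder path
os.makedirs(cache_dir, exist_ok=True)

mgr = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes=8)
mgr.append(b"\x00" * 6, name="a", shape=(6,), dtype="uint8", encode_format="raw")
# 6 + 6 bytes exceeds the cap and "b" alone is >= half of it, so "b" is written
# out-of-band as its own shard, while "a" stays pending until finish().
mgr.append(b"\x01" * 6, name="b", shape=(6,), dtype="uint8", encode_format="raw")
records = mgr.finish()
assert len(records) == 2  # params_shard_0.bin holds "b", params_shard_1.bin holds "a"
```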
+ + shard_cap_mb: int + Maxinum number of MB to be kept per shard + """ + if encode_format not in ("raw", "f32-to-bf16"): + raise ValueError(f"Invalie encode_format {encode_format}") + + meta_data = {} if meta_data is None else meta_data + records = [] + total = len(params) + counter = 0 + max_out_length = 0 + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + f32_to_bf16_triggered = False + + print("Start storing to cache %s" % cache_dir) + shard_cap_nbytes = shard_cap_mb * (1 << 20) + + shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes) + + for k, v in params.items(): + shape = list(v.shape) + + if not isinstance(v, np.ndarray): + v = v.numpy() + + # convert fp32 to bf16 + if encode_format == "f32-to-bf16" and v.dtype == "float32": + data = _convert_f32_to_bf16(v).tobytes() + dtype = "bfloat16" + f32_to_bf16_triggered = True + else: + data = v.tobytes() + dtype = str(v.dtype) + + shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format) + + counter += 1 + last_cmd = "[%04d/%04d] saving %s" % (counter, total, k) + flush = "\r" + (" " * max_out_length) + "\r" + max_out_length = max(len(last_cmd), max_out_length) + sys.stdout.write(flush + last_cmd) + + records = shard_manager.finish() + + nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json") + + with open(nd_cache_json, "w") as outfile: + json.dump({"metadata": meta_data, "records": records}, outfile, indent=4) + print( + f"\nAll finished, %d total shards committed, record saved to %s" + % (shard_manager.counter, nd_cache_json) + ) + + if f32_to_bf16_triggered: + for shard in records: + for item in shard["records"]: + if item["dtype"] == "float32": + item["format"] = "raw" + item["dtype"] = "bfloat16" + b16_nd_cache_json = os.path.join(cache_dir, "ndarray-cache-b16.json") + # also dump a file that contains bf16 + with open(b16_nd_cache_json, "w") as outfile: + json.dump({"metadata": meta_data, "records": records}, outfile, indent=4) + print("Also saved a bf16 record to %s" % b16_nd_cache_json) + + +def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device): + """Load the ndarray cache from the directory or json. + + + Parameters + ---------- + cachepath: str + Path to the location or json file. + + device: tvm.runtime.Device + The device we would like to load the data from. 
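`dump_ndarray_cache` and `load_ndarray_cache` are designed to round-trip: the default `"f32-to-bf16"` encoding is lossy, so this self-check sketch uses `"raw"` for an exact comparison. The cache directory is a made-up path:

```python
# Round-trip sketch using the lossless "raw" encoding.
import numpy as np
import tvm
from tvm.contrib import tvmjs

params = {"w0": np.random.uniform(size=(16, 16)).astype("float32")}
tvmjs.dump_ndarray_cache(params, "/tmp/ndarray_cache_demo", encode_format="raw")
loaded, metadata = tvmjs.load_ndarray_cache("/tmp/ndarray_cache_demo", tvm.cpu())
np.testing.assert_array_equal(loaded["w0"].numpy(), params["w0"])
```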
+ """ + if not cachepath.endswith(".json"): + cachepath = os.path.join(cachepath, "ndarray-cache.json") + + cachedir = os.path.dirname(cachepath) + json_info = json.loads(open(cachepath, "r").read()) + result_dict = {} + + for shard_rec in json_info["records"]: + data_path = shard_rec["dataPath"] + full_data_path = os.path.join(cachedir, data_path) + raw_data = open(full_data_path, "rb").read() + assert shard_rec["format"] == "raw-shard" + assert shard_rec["nbytes"] == len(raw_data) + + for rec in shard_rec["records"]: + name = rec["name"] + shape = rec["shape"] + dtype = rec["dtype"] + encode_format = rec["format"] + offset = rec["byteOffset"] + nbytes = rec["nbytes"] + + arr = tvm.nd.empty(shape, dtype, device=device) + assert offset + nbytes <= len(raw_data) + buffer_source = raw_data[offset : offset + nbytes] + if encode_format == "f32-to-bf16": + data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape) + arr.copyfrom(_convert_bf16_to_f32(data)) + else: + data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape) + arr.copyfrom(data) + result_dict[name] = arr + return result_dict, json_info["metadata"] + + +def export_runtime(runtime_dir): + """Export TVMJS runtime to the runtime_dir + + Parameters + ---------- + runtime_dir: str + The runtime directory + """ + web_hint = ( + "make sure you setup tvm web runtime correctly." + + " obtain a copy of TVM source code, set TVM_HOME env variable:\n" + + " cd /path/to/tvm/web; make; npm run bundle" + ) + + jsbundle = find_lib_path("tvmjs.bundle.js", optional=True) + if not jsbundle: + raise RuntimeError("Cannot find tvmjs.bundle.js, " + web_hint) + + wasi = find_lib_path("tvmjs_runtime.wasi.js", optional=True) + if not wasi: + raise RuntimeError("Cannot find tvmjs_runtime.wasi.js, " + web_hint) + + print(f"Copy {jsbundle[0]} to {runtime_dir}") + shutil.copy(jsbundle[0], runtime_dir) + print(f"Copy {wasi[0]} to {runtime_dir}") + shutil.copy(wasi[0], runtime_dir) diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 7eae4fe1742f..8cf1e4010bea 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -19,6 +19,7 @@ import logging import argparse import os +import glob from tvm.rpc.proxy import Proxy @@ -27,19 +28,38 @@ def find_example_resource(): curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") + default_plugin_page = os.path.join(base_path, "web", "apps", "browser", "rpc_plugin.html") + resource_files = [ - os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"), - os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js"), + ("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")), + ("/", os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")), + ("/", index_page), + ] + allow_format = ("json", "bin", "js", "wasm", "html") + + # recursively apend things in www, up to two levels + resource_bases = [ + os.path.join(base_path, "web", "dist", "www"), + os.path.join(base_path, "web", ".ndarray_cache"), ] - resource_base = os.path.join(base_path, "web", "dist", "www") - if os.path.isdir(resource_base): - for fname in os.listdir(resource_base): - full_name = os.path.join(resource_base, fname) - if os.path.isfile(full_name): - resource_files.append(full_name) - for fname in [index_page] + resource_files: + for base in resource_bases: + if not os.path.isdir(base): + continue 
+ for full_name in glob.glob("%s/**" % base, recursive=True): + fname = os.path.relpath(full_name, base) + dirname = os.path.dirname(fname) + fmt = fname.rsplit(".", 1)[-1] + if os.path.isfile(full_name) and fmt in allow_format: + resource_files.append((dirname, full_name)) + + for item in resource_files: + fname = item[-1] if not os.path.exists(fname): raise RuntimeError("Cannot find %s" % fname) + + if not any(item[-1].endswith("rpc_plugin.html") for item in resource_files): + resource_files.append(("/", default_plugin_page)) + return index_page, resource_files diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4f63cbecd9d1..01fea2abbda7 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -34,6 +34,7 @@ from .container import Array, Map from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr from .function import BaseFunc, CallingConv +from .global_info import GlobalInfo, DummyGlobalInfo from .memory_pools import ( ConstantMemoryPools, ConstantPoolInfo, diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 3c3fefb6d6c6..721e12e7f8d9 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -51,6 +51,17 @@ def checked_type(self): raise ValueError("The type checker has not populated" " the checked_type for this node") return ret + @property + def struct_info(self) -> "tvm.relax.StructInfo": + """Get the struct info field + + Returns + ------- + struct_info : tvm.relax.StructInfo + The struct info if available. + """ + return _ffi_api.ExprStructInfo(self) + @tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): @@ -82,10 +93,17 @@ def __call__(self, *args): A call taking the variable as a function. """ # pylint: disable=import-outside-toplevel + + # TODO(@relax-team): replace with Relax base class after it's introduced if all(isinstance(x, RelayExpr) for x in args): - from tvm import relay + if all(is_relax_expr(x) for x in args): + from tvm import relax + + return relax.Call(self, args) + else: + from tvm import relay - return relay.Call(self, args) + return relay.Call(self, args) arg_types = [type(x) for x in args] raise RuntimeError( "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) @@ -174,3 +192,42 @@ def from_min_extent(min_value, extent, span=None): The constructed range. """ return _ffi_api.Range_from_min_extent(min_value, extent, span) + + +# TODO(@relax-team): remove when we have a RelaxExpr base class +def is_relax_expr(expr: RelayExpr) -> bool: + """check if a RelayExpr is a Relax expresssion. + + Parameters + ---------- + expr : RelayExpr + The expression to check. + + Returns + ------- + res : bool + If the expression is Relax expression, return True; otherwise return False. + """ + from tvm import relax # pylint: disable=import-outside-toplevel + + if isinstance( + expr, + ( + relax.Call, + relax.Constant, + relax.Tuple, + relax.TupleGetItem, + relax.If, + relax.Var, + relax.DataflowVar, + relax.ShapeExpr, + relax.SeqExpr, + relax.Function, + relax.ExternFunc, + relax.PrimValue, + relax.StringImm, + relax.DataTypeImm, + ), + ): + return True + return False diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c3f1bf5f562a..b64553d31ce1 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
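The `GlobalVar.__call__` change in `python/tvm/ir/expr.py` above picks which IR to build from the argument types: when every argument is a Relax expression (per `is_relax_expr`), it constructs a `relax.Call`, otherwise it falls back to `relay.Call`. A small sketch of the Relax side, with an arbitrarily chosen tensor shape:

```python
# Sketch of the new dispatch: all-Relax arguments produce a relax.Call.
import tvm
from tvm import relax

gvar = tvm.ir.GlobalVar("my_func")                               # placeholder callee
x = relax.Var("x", relax.TensorStructInfo((4,), "float32"))
call = gvar(x)
assert isinstance(call, relax.Call)   # a relay.Var argument would yield relay.Call instead
```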
-"""Function defintiions.""" +"""Function definitions.""" +from typing import Union, Dict from enum import IntEnum import tvm.runtime - +from tvm.runtime.object import Object from .expr import RelayExpr +from .attrs import DictAttrs from . import _ffi_api @@ -38,7 +40,7 @@ def attrs(self): """Return the attrs member of the function.""" return _ffi_api.BaseFunc_Attrs(self) - def with_attr(self, attr_key_or_dict, attr_value=None): + def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc": """Create a new copy of the function and update the attribute. Parameters @@ -51,7 +53,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None): Returns ------- - func : Function + func : BaseFunc A new copy of the function """ # make sure we first copy so that we can safely do copy on write @@ -66,3 +68,35 @@ def with_attr(self, attr_key_or_dict, attr_value=None): return _ffi_api.BaseFuncWithAttr( res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "BaseFunc": + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + func : BaseFunc + A new copy of the function + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.BaseFuncWithAttrs(self, attr_map) + + def without_attr(self, attr_key: str) -> "BaseFunc": + """Create a new copy of the function with an attribute without provided key. + + Parameters + ---------- + attr_key : str + The attribute key to delete from the attrubte pairs. + + + Returns + ------- + func : BaseFunc + A new copy of the function + """ + return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py new file mode 100644 index 000000000000..17011e76a66c --- /dev/null +++ b/python/tvm/ir/global_info.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Global Info.""" +import tvm +from tvm.runtime.object import Object +from . 
import _ffi_api + + +class GlobalInfo(Object): + """Base node for all global info that can appear in the IR""" + + def __eq__(self, other): + """Compare two struct info for structural equivalence.""" + return tvm.ir.structural_equal(self, other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Overload with structural equality.""" + return super().__eq__(other) + + +class DummyGlobalInfo(GlobalInfo): + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DummyGlobalInfo, + ) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..707d46d0cdf8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,13 +15,18 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" +from __future__ import annotations + +from typing import Dict, Union import tvm._ffi from tvm._ffi.base import string_types from tvm.runtime import Scriptable +from tvm.runtime.object import Object from . import _ffi_api from . import expr as _expr from . import type as _ty +from .attrs import DictAttrs from .base import Node @@ -37,7 +42,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None, global_infos=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +65,20 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + if global_infos is None: + global_infos = {} + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + global_infos, + ) def __setitem__(self, var, val): """Add a mapping to the module. @@ -135,6 +153,19 @@ def update_func(self, var, func): """ return _ffi_api.Module_UpdateFunction(self, var, func) + def update_global_info(self, name, global_info): + """Update global info in the module + + Parameters + ---------- + name: str + The name for the global info. + + global_info: List[GlobalInfo] + The global info to be updated. + """ + return _ffi_api.Module_UpdateGlobalInfo(self, name, global_info) + def get_global_var(self, name): """Get a global variable in the function by name. @@ -286,6 +317,36 @@ def with_attr(self, attr_key, attr_value): return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + def without_attr(self, attr_key: str) -> "IRModule": + """Copy the IRModule and remove an attribute key and its associated value. + Parameters + ---------- + attr_key : str + The attribute key. + Returns + ------- + mod : IRModule + A new copy of the IRModule without the attribute + """ + + return _ffi_api.Module_WithoutAttr(self, attr_key) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "IRModule": + """Copy the IRModule and add the given attribute map to it. 
+ Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + mod : IRModule + A new copy of the IRModule with the attribute + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.Module_WithAttrs(self, attr_map) + def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index f7d40dc68147..c93b4eda2664 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -44,8 +44,10 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ - def __init__(self, opt_level, name, required=None): - self.__init_handle_by_constructor__(_ffi_transform_api.PassInfo, opt_level, name, required) + def __init__(self, opt_level, name, required=None, traceable=False): + self.__init_handle_by_constructor__( + _ffi_transform_api.PassInfo, opt_level, name, required, traceable + ) @tvm._ffi.register_object("transform.PassContext") @@ -69,6 +71,20 @@ class PassContext(tvm.runtime.Object): config : Optional[Dict[str, Object]] Additional configurations for specific passes. + + trace: Optional[relax.tuning.Trace] + Initial trace for trace mode. + + trace_stack: Optional[List[relax.tuning_api.Trace]] + Initial trace stack for trace mode. + + make_traceable: Optional[List[str]] + List of passes to make traceable. + + num_evals: int + initial number of evaluations conducted in the pipeline. + + tuning_api_database: Optional[relax.tuning_api.JSONDatabase] """ def __init__( @@ -78,6 +94,11 @@ def __init__( disabled_pass=None, instruments=None, config=None, + trace=None, + trace_stack=None, + make_traceable=None, + num_evals=0, + tuning_api_database=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -91,9 +112,25 @@ def __init__( if not isinstance(instruments, (list, tuple)): raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + # Convert to Map + # TODO(sunggg): Replace this to Set equivalent if exists + make_traceable = {name: True for name in make_traceable} if make_traceable else None + + if not trace_stack: + trace_stack = [trace] if trace else [] + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config + _ffi_transform_api.PassContext, + opt_level, + required, + disabled, + instruments, + config, + trace_stack, + make_traceable, + num_evals, + tuning_api_database, ) def __enter__(self): @@ -130,6 +167,47 @@ def list_configs(): """ return _ffi_transform_api.ListConfigs() + def push_trace(self, trace): + """Push a trace into the stack.""" + return _ffi_transform_api.PushTrace(self, trace) + + def pop_trace(self, return_current=True): + """Pop a topmost trace from the stack. 
+ Returns + ------- + Trace : Optional[relax.tuning.Trace] + """ + if return_current: + cur_trace = self.get_current_trace() + _ffi_transform_api.PopTrace(self) + return cur_trace + + return _ffi_transform_api.PopTrace(self) + + def get_trace_stack(self): + """Get the current trace stack.""" + return _ffi_transform_api.GetTraceStack(self) + + def get_trace_stack_size(self): + """Get the size of current stack.""" + return _ffi_transform_api.GetTraceStackSize(self) + + def get_current_trace(self): + """Get the trace on the top of the stack.""" + return _ffi_transform_api.GetCurrentTrace(self) + + def set_num_evals(self, num: int): + """Set the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.SetNumEvals(self, num) + + def inc_num_evals(self, num: int): + """Increment the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.IncNumEvals(self, num) + + def get_tuning_api_database(self): + """Get tuning api database.""" + return _ffi_transform_api.GetTuningAPIDatabase(self) + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): @@ -198,7 +276,7 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. """ - def __init__(self, passes=None, opt_level=0, name="sequential", required=None): + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") @@ -208,7 +286,7 @@ def __init__(self, passes=None, opt_level=0, name="sequential", required=None): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__( - _ffi_transform_api.Sequential, passes, opt_level, name, required + _ffi_transform_api.Sequential, passes, opt_level, name, required, traceable ) @@ -244,7 +322,7 @@ def __getattr__(self, name): return PyModulePass -def module_pass(pass_func=None, opt_level=None, name=None, required=None): +def module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False): """Decorate a module pass. This function returns a callback when pass_func is provided. @@ -269,6 +347,9 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): required : Optional[List[str]] The list of passes that the module pass is dependent on. 
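The `PassContext` extension keeps a stack of tuning traces so trace-mode passes can record and replay transformation decisions. A minimal sketch of the bookkeeping calls; building an actual `relax.tuning_api.Trace` is omitted here:

```python
# Minimal sketch, assuming no trace has been pushed yet.
from tvm.ir.transform import PassContext

ctx = PassContext(opt_level=3)
print(ctx.get_trace_stack_size())          # 0: the stack starts empty
# With a relax.tuning_api.Trace object `t`:
#   ctx.push_trace(t)
#   cur = ctx.get_current_trace()
#   ctx.pop_trace()                        # returns `cur` by default
```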
+ traceable: Boolean + Boolean variable whether the module pass is traceable + Returns ------- create_module_pass : Union[Callable, ModulePass] @@ -336,7 +417,7 @@ def transform(mod, ctx): def create_module_pass(pass_arg): """Internal function that creates a module pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) if not callable(pass_arg): diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 30a4fc6d9467..21a11ff9e84d 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -25,6 +25,7 @@ mutator, postproc, relay_integration, + relax_integration, runner, schedule, schedule_rule, diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py new file mode 100644 index 000000000000..db22214b768f --- /dev/null +++ b/python/tvm/meta_schedule/relax_integration.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + +from tvm._ffi import get_global_func, register_func +from tvm.ir import IRModule +from tvm.ir.transform import PassContext +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.tir.expr import IntImm + +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .extracted_task import ExtractedTask +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed + +if TYPE_CHECKING: + from tvm import relax + +_extract_task_func = get_global_func( # pylint: disable=invalid-name + "relax.backend.MetaScheduleExtractTask", + allow_missing=False, +) + + +def extract_tasks( + mod: Union[IRModule, "relax.Function"], + target: Target, + params: Optional[Dict[str, NDArray]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relax program. 
+ + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this module + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import Function as RelaxFunc + from tvm.relax.transform import BindParams + + # pylint: enable=import-outside-toplevel + if isinstance(mod, RelaxFunc): + mod = IRModule({"main": mod}) + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + return list(_extract_task_func(mod, target)) + + +def extracted_tasks_to_tune_contexts( + extracted_tasks: List[ExtractedTask], + work_dir: str, + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Tuple[List[TuneContext], List[float]]: + """Convert ExtractedTask to TuneContext. + + Parameters + ---------- + tasks : List[ExtractedTask] + The tasks to be converted + work_dir : str + The working directory to store logs and databases + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use in multi-threaded search algorithm. + seed : Optional[int] + The random seed to use. + + Returns + ------- + tasks : List[TuneContext] + The converted tasks + task_weights : List[float] + The weights of the tasks + """ + tasks: List[TuneContext] = [] + task_weights: List[float] = [] + for task, logger, rand_state in zip( + extracted_tasks, + get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), + fork_seed(seed, n=len(extracted_tasks)), + ): + tasks.append( + TuneContext( + mod=task.dispatched[0], + target=task.target, + space_generator=space, + search_strategy=strategy, + task_name=task.task_name, + logger=logger, + rand_state=rand_state, + num_threads=num_threads, + ).clone() + ) + task_weights.append(task.weight) + return tasks, task_weights + + +def tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Tune a Relax program. 
+ + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + ) + + +@register_func("tvm.meta_schedule.tune_relax") +def _tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a Relax program. 
+ + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + + tune_relax( + mod, + params, + target, + work_dir, + max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + +def compile_relax( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], +) -> "relax.Executable": + """Compile a relax program with a MetaSchedule database. + + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relax program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + + Returns + ------- + lib : relax.Executable + The built runtime module or vm Executable for the given relax workload. 
+ """ + # pylint: disable=import-outside-toplevel + from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase + from tvm.relax import build as relax_build + + # pylint: enable=import-outside-toplevel + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + + with target, database, PassContext(opt_level=3): + relax_mod = MetaScheduleApplyDatabase()(mod) + relax_ex = relax_build(relax_mod, target=target) + return relax_ex diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index f3d505c28b0e..d5f5ee86e0b8 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -22,7 +22,9 @@ # isort: on from tvm import ir, tir +from tvm._ffi import register_func from tvm.target import Target +from tvm.tir.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -128,6 +130,93 @@ def tune_tir( ) +@register_func("tvm.meta_schedule.tune_tir") +def _tune_tir( + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + task_name: str = "main", + num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a TIR program. + + Parameters + ---------- + mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + work_dir : str + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator. + strategy : SearchStrategy.SearchStrategyType + The search strategy. + task_name : str + The name of the task. + num_tuning_cores : Union[Literal["physical", "logical"], int] + The number of CPU cores to use during tuning. + seed : Optional[int] + The seed for the random number generator. 
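Taken together, `extract_tasks`, `tune_relax`, and `compile_relax` give an end-to-end flow: extract TIR tasks from a Relax module, tune them into a database, then apply the database and build a VM executable. A hedged sketch, assuming `mod` (an IRModule with a Relax `main`) and `params` are defined elsewhere; the target, work directory, and trial budget are placeholders:

```python
# Workflow sketch; `mod` and `params` come from a frontend importer or relay_translator.
import tvm
from tvm import meta_schedule as ms
from tvm import relax
from tvm.target import Target

target = Target("llvm -num-cores 4")
db = ms.relax_integration.tune_relax(
    mod, params, target, work_dir="./tuning_logs", max_trials_global=64
)
exe = ms.relax_integration.compile_relax(db, mod, target, params)
vm = relax.VirtualMachine(exe, tvm.cpu())   # run the tuned module via vm["main"](...)
```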
+ + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + tune_tir( + mod, + target, + work_dir, + max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + task_name=task_name, + num_tuning_cores=num_tuning_cores, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + def compile_tir( database: Database, mod: Union[ir.IRModule, tir.PrimFunc], diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index f15976b1cc47..bf9fa792234a 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -24,7 +24,7 @@ # isort: on from tvm import IRModule -from tvm._ffi import register_object +from tvm._ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule @@ -41,6 +41,7 @@ from .space_generator import SpaceGenerator +@register_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index fb1ddd6585f2..5bf96aef775e 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -75,14 +75,27 @@ def _extract(inst: type, name: str): def method(*args, **kwargs): return getattr(inst, name)(*args, **kwargs) - if getattr(base, name) is getattr(cls, name) and name != "__str__": - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - return method + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # for task scheduler return None means calling default function + # otherwise it will trigger a TVMError of method not implemented + # on the c++ side when you call the method, __str__ not required + return None assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) assert hasattr( cls, "_tvm_metadata" ), "Please use the user-facing method overriding class, i.e., PyRunner." 
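The `derived_object` helper in `meta_schedule/utils.py` now walks the MRO when extracting overridden methods and explicitly rejects subclassing an already-decorated class. A sketch of the intended usage pattern (the stub runner is illustrative):

```python
# Subclass the user-facing Py* class and decorate it; decorating a subclass of an
# already-decorated class raises TypeError (subclass EchoRunner._cls instead).
from typing import List
from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput
from tvm.meta_schedule.utils import derived_object

@derived_object
class EchoRunner(PyRunner):
    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
        raise NotImplementedError("illustrative stub")
```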
@@ -95,6 +108,9 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + _cls = cls + _type = "TVMDerivedObject" + def __init__(self, *args, **kwargs): """Constructor.""" self.handle = None @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs): # using weakref to avoid cyclic dependency self._inst._outer = weakref.ref(self) - def __getattr__(self, name: str): - """Bridge the attribute function.""" - try: - return self._inst.__getattribute__(name) - except AttributeError: - return super(TVMDerivedObject, self).__getattr__(name) + def __getattr__(self, name): + # fall back to instance attribute if there is not any + # return self._inst.__getattribute__(name) + import inspect # pylint: disable=import-outside-toplevel + + result = self._inst.__getattribute__(name) + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to aviod implicit deconstruction + setattr(method, "__own__", self) + return method + + return result def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py new file mode 100644 index 000000000000..5b17fb5d115d --- /dev/null +++ b/python/tvm/relax/__init__.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, wrong-import-position +"""The Relax IR namespace containing the IR, type, operator, builder, vm, etc.""" +from tvm.runtime import relax_vm as vm +from tvm.runtime.relax_vm import VirtualMachine, VMInstrumentReturnKind + +# Expr +from .expr import ( + Expr, + Span, + SourceName, + Id, + GlobalVar, + Var, + DataflowVar, + Binding, + MatchCast, + VarBinding, + BindingBlock, + DataflowBlock, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + Function, + ExternFunc, + Call, + If, + Constant, + PrimValue, + DataTypeImm, + StringImm, +) + +from .expr import const, extern, get_shape_of + +# Type +from .ty import Type, ObjectType, ShapeType, DynTensorType, TupleType, FuncType, PackedFuncType + +# VM +from .exec_builder import ExecBuilder + +# Operator +from .op.base import call_tir, call_dps_packed + +# BlockBuilder +from .block_builder import BlockBuilder + +# ExprFunctor +from .expr_functor import ExprFunctor, PyExprVisitor, PyExprMutator + +# StructInfo +from .struct_info import ( + StructInfo, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + TensorStructInfo, + TupleStructInfo, + FuncStructInfo, +) + +# pipeline +from .pipeline import get_pipeline + +# Import submodules in the last to avoid dependency +from . import exec_builder +from . import expr +from . import ty +from . import analysis +from . import transform +from . 
import block_builder +from . import op +from . import struct_info +from . import backend +from . import frontend + +# VM +from .vm_build import build, Executable + +from .binding_rewrite import DataflowBlockRewrite diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py new file mode 100644 index 000000000000..a127e1c81378 --- /dev/null +++ b/python/tvm/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax.""" +import tvm._ffi + +tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py new file mode 100644 index 000000000000..7ba56ff40840 --- /dev/null +++ b/python/tvm/relax/analysis/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax IR analysis. """ + +from .analysis import * +from .estimate_memory_usage import estimate_memory_usage diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py new file mode 100644 index 000000000000..40ee05c3960d --- /dev/null +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.analysis", __name__) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py new file mode 100644 index 000000000000..2a2c3d87b88d --- /dev/null +++ b/python/tvm/relax/analysis/analysis.py @@ -0,0 +1,424 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the set of passes for Relax, which exposes an interface for +configuring the passes and scripting them in Python. +""" + +from typing import Dict, List, Union, Callable +from enum import IntEnum + +import tvm +from tvm import tir +from tvm import IRModule +from tvm.relax.ty import Type +from tvm.relax.struct_info import StructInfo, FuncStructInfo +from tvm.relax.expr import DataflowBlock, Var, GlobalVar, Expr, Function, Call, Binding +from tvm.tir import IndexMap, PrimFunc, Block, Buffer +from . import _ffi_api + + +def get_static_type(sinfo: StructInfo) -> Type: + """Get the corresponding static type from a StructInfo. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + Returns + ------- + ret : Type + The corresponding static type. + """ + return _ffi_api.GetStaticType(sinfo) # type: ignore + + +def erase_to_well_defined( + sinfo: StructInfo, + shape_var_map: Dict[tir.Var, tir.PrimExpr] = None, + var_map: Dict[Var, Expr] = None, +) -> StructInfo: + """Erase sinfo into a well defined form. + + This function removes the StructInfo's dependencies on shape and vars that + are not defined in given maps. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + shape_var_map : Dict[tir.Var, tir.PrimExpr] + Specifies the defined shape vars and the values they should map to. + + var_map : Dict[Var, Expr] + Specifies the defined vars and the values they should map to. + + Returns + ------- + ret : StructInfo + The corresponding erased struct info. + """ + shape_var_map = {} if shape_var_map is None else shape_var_map + var_map = {} if var_map is None else var_map + + return _ffi_api.EraseToWellDefined(sinfo, shape_var_map, var_map) # type: ignore + + +class BaseCheckResult(IntEnum): + """Return result of fine-grained base check. + + Note + ---- + Base check comes with fine-grained fail levels. + + - FAIL_L0: The lhs and rhs have no intersection at all. + - FAIL_L1: We get the failure by looking at static information. + - FAIL_L2: We get the failure due to unknown symbolic variable relations. + """ + + FAIL_L0 = 0 + FAIL_L1 = 1 + FAIL_L2 = 2 + PASS = 3 + + +def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckResult: + """Run a base check to see if base subsumes derived. 
+ + Parameters + ---------- + base: StructInfo + The base struct info. + + derived: StructInfo + The derived struct info. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + """ + return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore + + +def derive_call_ret_struct_info( + func_sinfo: FuncStructInfo, call: Call, ctx: "tvm.relax.BlockBuilder" +) -> StructInfo: + """Derive the call's ret value struct info from inputs. + + Parameters + ---------- + func_sinfo: FuncStructInfo + The call's function signature. + + call: Call + The call expression + + ctx: tvm.relax.BlockBuilder + The context block builder. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + + Note + ---- + This is an internal derivation function, call.op field is + ignored in this case and the derivation only depends on func_sinfo. + """ + return _ffi_api.DeriveCallRetStructInfo(func_sinfo, call, ctx) # type: ignore + + +def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: + """Unify the two struct info to their least common ancestor. + + Parameters + ---------- + lhs: StructInfo + The left operand. + + rhs: StructInfo + The right operand. + + Returns + ------- + ret : StructInfo + The corresponding lca result. + """ + return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore + + +def bound_vars(expr: Expr) -> List[Var]: + """ + Return all bound variables from expression expr. + Bound variables are all variables that are declared in the expr. + They only have meaning inside that expr, and can only be used in it. + Parameters + ---------- + expr: Expr + The expression. + Returns + ------- + ret: List[Var] + List of bound vars in expr, in post-DFS order + """ + return _ffi_api.bound_vars(expr) + + +def free_vars(expr: Expr) -> List[Var]: + """ + Return all free variables from expression expr. + Free variables are variables that are not bound by a + VarBinding or a function parameter in the expression. + Parameters + ---------- + expr: Expr + The expression. + Returns + ------- + ret: List[Var] + List of free vars in expr, in post-DFS order + """ + return _ffi_api.free_vars(expr) + + +def all_vars(expr: Expr) -> List[Var]: + """ + Return all (local) variables from expression expr. + Parameters + ---------- + expr: Expr + The expression. + Returns + ------- + ret: List[Var] + List of vars in expr, in post-DFS order + """ + return _ffi_api.all_vars(expr) + + +def all_global_vars(expr: Expr) -> List[GlobalVar]: + """ + Return all global variables from expression expr. + Parameters + ---------- + expr: Expr + The expression. + Returns + ------- + ret: List[GlobalVar] + List of global vars in expr, in post-DFS order + """ + return _ffi_api.all_global_vars(expr) + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. + """ + return _ffi_api.post_order_visit(expr, fvisit) # type: ignore + + +def has_reshape_pattern(func: tir.PrimFunc) -> bool: + """Check if the given PrimFunc is essentially doing a reshape operation. + The reshape operation also includes expand_dims, squeeze, flatten, etc. 
+ + Here the allowed reshape pattern is: for example, assume the operation is + `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove + that the flattened index of l_0, ..., l_b under buffer B equals to the + flattened index of r_0, ..., r_a under buffer A. + + Parameters + ---------- + func : tir.PrimFunc + The function to be examined. + + Returns + ------- + ret : bool + A boolean indicating if the given PrimFunc is doing a reshape. + + Notes + ----- + According to the description above, the returned result can only be + false-negative and cannot be false-positive, since whenever we cannot + prove the equality, we return false. This property guarantees the safety + of this function. + """ + return _ffi_api.has_reshape_pattern(func) # type: ignore + + +def get_var2val(func: Function) -> Dict[Var, Expr]: + """ + Get a mapping from Var to Expr for each variable in the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Dict[Var, Expr] + A mapping from Var to Expr. + """ + return _ffi_api.get_var2val(func) # type: ignore + + +def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: + """ + Analyze the variable use-def chain in a dataflow block. + + Parameters + ---------- + dfb : DataflowBlock + The dataflow block to analyze + + Returns + ------- + Dict[Var, List[Var]] + A mapping from variable definition to its uses. + """ + return _ffi_api.udchain(dfb) # type: ignore + + +def name_to_binding(func: Function) -> Dict[str, List[Binding]]: + """Return a map from variable name to its bindings.""" + return _ffi_api.name_to_binding(func) # type: ignore + + +def remove_all_unused(func: Function) -> Function: + """It removes: + 1. Unused local VarBindings in a DataflowBlock. + 2. Unused DataflowBlocks in a function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Notes + ----- + For IRModule-wise DCE, use py:func:`tvm.relax.transform.DeadCodeElimination`. + + Returns + ------- + Function + The function with unused variables removed. + """ + return _ffi_api.remove_all_unused(func) # type: ignore + + +def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: + """Check if the IRModule is well formed. + + Parameters + ---------- + mod : tvm.IRModule + The input IRModule. + + check_struct_info : bool + A boolean flag indicating if the property "every Expr must + have defined structure info" will be checked. + + Returns + ------- + ret: bool + True if the IRModule is well formed, False if not. + + Note + ---- + By default the structure info is always checked. It is only in test cases + where `check_struct_info` might be false, so that other well-formed requirements + will be well tested and will not be blocked by not having structure info. + """ + return _ffi_api.well_formed(mod, check_struct_info) # type: ignore + + +def suggest_layout_transforms( + func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]] +) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]: + """Suggest Layout transformations of blocks and buffers in a PrimFunc. + + Parameters + ---------- + func: PrimFunc + PrimFunc on which analysis will be performed and transformations suggested. + + write_buffer_transforms: List[Union[IndexMap, Callable] + List of layout transformations on the output buffers. The number of layout + transformations must match the number of outputs of the PrimFunc. 
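+        Each transformation may be given either as an IndexMap or as a
+        callable; a callable is converted to an IndexMap via
+        IndexMap.from_func before the analysis runs.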
+ + Returns + ------- + ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]] + Suggested transforms per block in `func`. For each block the returned value is a map + from the object (block or buffer) to it's index map transformation. + """ + write_buffer_index_maps = [] + for transform in write_buffer_transforms: + if callable(transform): + transform = IndexMap.from_func(transform) + assert isinstance(transform, IndexMap) + write_buffer_index_maps.append(transform) + return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps) # type: ignore + + +def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]: + """ + Find all sets of recursive or mutually recursive functions in the module. + + Two or more functions are mutually recursive if there is some cycle of references + among them. For example, if there are two functions A and B, they are + mutually recursive if A calls B and B calls A. Another case would be with + three functions A, B, and C, where A calls B, B calls C, and C calls A. + + (Note that functions do not have to call each other to reference each other. + For example, if a function returns another function, that is still a reference + that could potentially be recursive, even without a call.) + + + If a function is simply recursive and not mutually recursive with any other, + it will be reported as a group by itself. + + Parameters + ---------- + mod: The module + + Returns + ------- + ret: List[List[GlobalVar]] + Each member of the list is a list of global functions + that references each other mutually recursively. + If a function is simply recursive and not mutually recursive + with any other, it will be a singleton in this list. + """ + return _ffi_api.detect_recursion(mod) # type: ignore diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py b/python/tvm/relax/analysis/estimate_memory_usage.py new file mode 100644 index 000000000000..55f82740ec9c --- /dev/null +++ b/python/tvm/relax/analysis/estimate_memory_usage.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=abstract-method,unused-argument +# pylint: disable=missing-function-docstring,missing-module-docstring +from typing import Union +import tvm +from tvm.ir import Op +from tvm.ir.module import IRModule + +from ..expr import Call, Expr, Function, ShapeExpr +from ..expr_functor import visitor, PyExprVisitor + + +def estimate_memory_usage(mod: Union[IRModule, Function]) -> str: + """Analysis function that estimates the memory usage of Relax functions + in an IRModule. The estimation includes the total memory size needed to + be allocated before and after memory planning. + + The result might be over-estimated, as the estimation is static, which + does not consider control flows (such as "if" and cross-function calls). 
+ It simply accumulates the size of every alloc_tensor and alloc_storage. + + This analysis function is used to demonstrate the effect of memory + planning. + + Parameters + ---------- + mod : Union[IRModule, Function] + The input IRModule whose functions inside are to be analyzed. + If the input is a Function, we will wrap it with a IRModule, with + the function named "main". + + Returns + ------- + est : str + The estimation information, in the form of a string. + + Notes + ----- + We regards "relax.memory.alloc_tensor/storage" as the results produced by memory planning. + """ + + @visitor + class MemoryEstimator(PyExprVisitor): + """The IR visitor which estimates the memory usage of each Relax function. + + Attributes + ---------- + total_alloc_tensor_mem : int + The total memory size of alloc_tensor, in bytes. + + total_const_size_tensor_num : int + The number of constant-size tensors. + + total_dyn_size_tensor_num : int + The number of dynamic-size tensors. + + planned_alloc_mem : int + The total memory size of memory.alloc_storage after memory planning, in bytes. + + planned_mem_num : int + The number of memory.alloc_storages. + """ + + total_alloc_tensor_mem: int + total_const_size_tensor_num: int + total_dyn_size_tensor_num: int + planned_alloc_mem: int + planned_mem_num: int + builtin_alloc_tensor_op = Op.get("relax.builtin.alloc_tensor") + memory_alloc_tensor_op = Op.get("relax.memory.alloc_tensor") + memory_alloc_storage_op = Op.get("relax.memory.alloc_storage") + + def estimate(self, mod: IRModule) -> str: + estimation: str = "" + for global_var, func in mod.functions.items(): + if not isinstance(func, Function): + continue + + self.cleanup() + self.visit_expr(func) + estimation += self.generate_est_string(global_var.name_hint) + + if estimation != "": + estimation = "Memory usage estimation:\n" + estimation + return estimation + + def cleanup(self) -> None: + self.total_alloc_tensor_mem = 0 + self.total_const_size_tensor_num = 0 + self.total_dyn_size_tensor_num = 0 + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + + def visit_call_(self, call: Call) -> None: # pylint: disable=arguments-differ + if call.op == self.builtin_alloc_tensor_op: + self.accumulate_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value) + elif call.op == self.memory_alloc_tensor_op: + self.accumulate_tensor_alloc(shape=call.args[2], dtype_str=call.args[3].value) + elif call.op == self.memory_alloc_storage_op: + self.accumulate_storage_alloc(size=call.args[0]) + + def accumulate_tensor_alloc(self, shape: Expr, dtype_str: str) -> None: + if not isinstance(shape, ShapeExpr): + raise TypeError( + "The shape of relax.builtin.alloc_tensor and " + "relax.memory.alloc_tensor is expected to be ShapeExpr" + ) + size: int = 1 + for dim_len in shape.values: + if not isinstance(dim_len, tvm.tir.IntImm): + self.total_dyn_size_tensor_num += 1 + return + size *= dim_len.value + + dtype = tvm.DataType(dtype_str) + self.total_const_size_tensor_num += 1 + self.total_alloc_tensor_mem += (size * dtype.bits * dtype.lanes + 7) // 8 + + def accumulate_storage_alloc(self, size: Expr) -> None: + if not isinstance(size, ShapeExpr): + raise TypeError( + "The size of relax.memory.alloc_storage is expected to be ShapeExpr" + ) + + self.planned_mem_num += 1 + self.planned_alloc_mem += size.values[0].value + + def generate_est_string(self, func_name: str) -> str: + est = ( + f" * Without memory planning, there are {self.total_const_size_tensor_num} " + "constant-size memory allocation(s) with total size " + "{0:.4} 
GB".format(self.total_alloc_tensor_mem / 2**30) + ) + if self.total_dyn_size_tensor_num > 0: + est += f", and {self.total_dyn_size_tensor_num} dynamic-size allocation(s)" + est += ( + f".\n * With memory planning, there are {self.planned_mem_num} constant-size " + "memory allocation(s) with total size " + "{0:.4} GB.\n".format(self.planned_alloc_mem / 2**30) + ) + est += " * Memory planning reduces constant memory size to " "{0:.1%}.".format( + self.planned_alloc_mem / self.total_alloc_tensor_mem + ) + return "- Function " + func_name + ":\n" + est + + if isinstance(mod, Function): + mod = tvm.IRModule({tvm.ir.GlobalVar("foo"): mod}) + + return MemoryEstimator().estimate(mod) diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py new file mode 100644 index 000000000000..c3786591e310 --- /dev/null +++ b/python/tvm/relax/backend/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax backends""" + +from . import contrib +from .pattern_registry import get_pattern, get_patterns_with_prefix diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py new file mode 100644 index 000000000000..d1378b2eacc2 --- /dev/null +++ b/python/tvm/relax/backend/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax backend.""" + +import tvm._ffi + +tvm._ffi._init_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/contrib/__init__.py b/python/tvm/relax/backend/contrib/__init__.py new file mode 100644 index 000000000000..a094c97d24bf --- /dev/null +++ b/python/tvm/relax/backend/contrib/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""External backend codegen modules for Relax.""" + +from .cutlass import partition_for_cutlass diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py new file mode 100644 index 000000000000..c03c913d63cd --- /dev/null +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for CUTLASS backend""" + +from typing import Mapping, Optional, Sequence, Tuple + +import tvm +from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul +from tvm.relax import DataflowVar, ShapeExpr, Var, transform +from tvm.relax.transform import PatternCheckContext + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import ( + make_attention_pattern, + make_fused_bias_activation_pattern, + make_matmul_pattern, + make_residual_block_pattern, +) + + +def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]: + result = [] + for dim in shape.values: + if isinstance(dim, tvm.tir.expr.IntImm): + result.append(int(dim)) + else: + return None + return result + + +def _is_supported_dtype(lhs_dtype, rhs_dtype): + """Check if dtypes in the given workload are supported by CUTLASS.""" + return ( + (lhs_dtype == "float16" and rhs_dtype == "float16") + or (lhs_dtype == "float32" and rhs_dtype == "float32") + or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8")) + ) + + +def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool: + """ + Check whether intermediate variables in the region to be fused are used outside + the fused region. + """ + defined_vars = set(context.matched_bindings.keys()) + output_var = context.value_to_bound_var[context.matched_expr] + intermediate_vars = {v for v in context.matched_bindings if v != output_var} + + if any(not isinstance(v, DataflowVar) for v in intermediate_vars): + # If intermediate variable is not a DataflowVar, it can be accessed and potentially + # used outside the DataflowBlock. + return True + + # Check whether all users of an intermediate variable are inside the fused region. 
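+    # A user that is not among the matched bindings means the intermediate value
+    # is consumed outside the matched region, so the region cannot be fused as a whole.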
+ for var in intermediate_vars: + if any(var_user not in defined_vars for var_user in context.var_usages[var]): + return True + + return False + + +def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]): + if from_var == to_var: + return True + + checked = set() + vars_to_check = [to_var] + while vars_to_check: + current_var = vars_to_check.pop() + for user in var_usages.get(current_var, []): + if user == from_var: + return True + if user not in checked: + checked.add(user) + vars_to_check.append(user) + + return False + + +def _check_conv2d(context: PatternCheckContext) -> bool: + """Check if the given conv2d workload can be offloaded to CUTLASS.""" + if _has_leaking_intermediate_variables(context): + return False + + conv2d_call = context.annotated_expr["root"] + data_layout = conv2d_call.attrs.data_layout + kernel_layout = conv2d_call.attrs.kernel_layout + data, weight, *_ = conv2d_call.args + if ( + data_layout != "NHWC" + or kernel_layout != "OHWI" + or not _is_supported_dtype(data.struct_info.dtype, weight.struct_info.dtype) + ): + return False + + if "residual" in context.annotated_expr: + residual = context.annotated_expr["residual"] + if not isinstance(residual, Var): + residual = context.value_to_bound_var[residual] + conv2d_var = context.value_to_bound_var[conv2d_call] + if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages): + # If residual depends on the result of conv2d, this cannot be handled by cutlass. + return False + + # pylint: disable=invalid-name + IC = data.struct_info.shape.values[3] + OC = weight.struct_info.shape.values[0] + # not depthwise conv2d + return not IC == OC == conv2d_call.attrs.groups + + +def _check_matmul(context: PatternCheckContext) -> bool: + """Check if the given matmul workload can be offloaded to CUTLASS.""" + if _has_leaking_intermediate_variables(context): + return False + + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + + lhs_dtype = lhs.struct_info.dtype + rhs_dtype = rhs.struct_info.dtype + if not _is_supported_dtype(lhs_dtype, rhs_dtype): + return False + + lhs_shape = lhs.struct_info.shape.values + rhs_shape = rhs.struct_info.shape.values + return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape) + + +def _get_activation_from_name(pattern_name): + if "_relu" in pattern_name: + return "relax.nn.relu" + elif "_gelu" in pattern_name: + return "relax.nn.gelu" + elif "_silu" in pattern_name: + return "relax.nn.silu" + else: + return None + + +def matmul_patterns(): + """ + Returns a list of all matmul patterns in cutlass BYOC backend. + """ + + def _matmul_pattern(pattern_name): + transposed_rhs = "_transposed" in pattern_name + with_bias = "_bias" in pattern_name + activation = _get_activation_from_name(pattern_name) + + return ( + pattern_name, + *make_matmul_pattern( + transposed_rhs=transposed_rhs, + with_bias=with_bias, + activation=activation, + ), + _check_matmul, + ) + + return [ + _matmul_pattern("cutlass.matmul"), + _matmul_pattern("cutlass.matmul_bias"), + _matmul_pattern("cutlass.matmul_bias_relu"), + _matmul_pattern("cutlass.matmul_bias_gelu"), + _matmul_pattern("cutlass.matmul_transposed"), + _matmul_pattern("cutlass.matmul_transposed_bias"), + _matmul_pattern("cutlass.matmul_transposed_bias_relu"), + _matmul_pattern("cutlass.matmul_transposed_bias_gelu"), + ] + + +def conv2d_patterns(): + """ + Returns a list of all conv2d patterns in cutlass BYOC backend. 
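+    Each entry is a (name, pattern, annotations, check) tuple that uses
+    _check_conv2d as the partition check function.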
+ """ + + def _conv2d_pattern(pattern_name): + with_bias = "_bias" in pattern_name + activation = _get_activation_from_name(pattern_name) + + return ( + pattern_name, + *make_fused_bias_activation_pattern( + "relax.nn.conv2d", + with_bias=with_bias, + activation=activation, + ), + _check_conv2d, + ) + + return [ + _conv2d_pattern("cutlass.conv2d"), + _conv2d_pattern("cutlass.conv2d_bias"), + _conv2d_pattern("cutlass.conv2d_bias_relu"), + _conv2d_pattern("cutlass.conv2d_bias_silu"), + ] + + +def residual_block_patterns(): + """ + Returns a list of all residual block patterns in cutlass BYOC backend. + """ + patterns = [] + + for activation, name_postfix in [(None, ""), ("relax.nn.relu", "_relu")]: + for check, base_patterns in [ + (_check_conv2d, conv2d_patterns()), + (_check_matmul, matmul_patterns()), + ]: + for name, pat, arg_pat, _ in base_patterns: + # Append residual patterns only to those base patterns with bias add, + # since conv2d or matmul + residual add without bias is already supported + # via conv2d or matmul + bias patterns (the residual input is treated as "bias"). + if "bias" in name: + for bin_op in ["relax.add", "relax.multiply"]: + patterns.append( + ( + name + "_residual_" + bin_op.split(".")[-1] + name_postfix, + *make_residual_block_pattern( + (pat, arg_pat), binary_op=bin_op, activation=activation + ), + check, + ) + ) + + return patterns + + +def attention_patterns(): + """ + Returns a list of all attention patterns in cutlass BYOC backend. + """ + return [ + ( + "cutlass.attention", + *make_attention_pattern(), + ), + ( + "cutlass.attention_bias", + *make_attention_pattern(with_bias=True), + ), + ] + + +register_patterns( + [ + *conv2d_patterns(), + *matmul_patterns(), + *residual_block_patterns(), + *attention_patterns(), + ] +) + + +def partition_for_cutlass(mod, annotate_codegen=True): + """ + Partition the input module into CUTLASS-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + annotate_codegen: bool + Whether to wrap each created composite function with another function, whose + body consists only of a call to the composite function. See the doc of FuseOpsByPattern + for more detail. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + compiled by the CUTLASS backend. + """ + + patterns = get_patterns_with_prefix("cutlass") + return transform.FuseOpsByPattern( + patterns, bind_constants=False, annotate_codegen=annotate_codegen + )(mod) diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py new file mode 100644 index 000000000000..5ec57164eb6b --- /dev/null +++ b/python/tvm/relax/backend/pattern_registry.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern registry for BYOC backends""" + +import atexit +from typing import Callable, List, Mapping, Optional, Set, Tuple, Union + +from tvm.relax.dpl import DFPattern +from tvm.relax.transform import FusionPattern + +from ..expr import Expr +from . import _ffi_api + +_REGISTERED_PATTERN_NAMES: Set[str] = set() + + +def _cleanup_registered_patterns(): + _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES)) # type: ignore # pylint: disable=no-member + + +_CLEANUP_REGISTERED = False + + +def _ensure_cleanup_function_registered(): + """ + Add a cleanup function to be called on interpreter termination, to remove all + patterns registered on the Python side. Without cleaning up those patterns, + program will segfault on termination. It's because the check functions of pattern + entries are referenced from the static memory of libtvm, thus they will be cleaned + up at the very end, making calls to Py_DecRef after Python interpreter terminates. + """ + global _CLEANUP_REGISTERED # pylint: disable=global-statement + + if not _CLEANUP_REGISTERED: + atexit.register(_cleanup_registered_patterns) + _CLEANUP_REGISTERED = True + + +CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool] +Pattern = Union[ + FusionPattern, + Tuple[str, DFPattern], + Tuple[str, DFPattern, Mapping[str, DFPattern]], + Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc], +] + + +def register_patterns(patterns: List[Pattern]): + """ + Register patterns which will be used to partition the DataflowBlock into + subgraphs that are supported by external backends. + + Parameters + ---------- + patterns: List[Pattern] + Patterns to be registered. Patterns that appear later in the list have + higher priority when partitioning DataflowBlock. + """ + _ensure_cleanup_function_registered() + + entries = [] + for item in patterns: + if isinstance(item, FusionPattern): + entries.append(item) + elif isinstance(item, tuple): + entries.append(FusionPattern(*item)) + _REGISTERED_PATTERN_NAMES.add(item[0]) + else: + raise TypeError(f"Cannot register type {type(item)} as pattern") + _ffi_api.RegisterPatterns(entries) + + +def get_patterns_with_prefix(prefix: str) -> List[FusionPattern]: + """ + Get a list of patterns whose names startwith `prefix`. + + Parameters + ---------- + prefix: str + The prefix of pattern name. + + Returns + ------- + patterns: FusionPattern + Matched patterns, ordered by priority from high to low. + """ + return _ffi_api.GetPatternsWithPrefix(prefix) + + +def get_pattern(name: str) -> Optional[FusionPattern]: + """ + Find the pattern with a particular name. + + Parameters + ---------- + name: str + The pattern name. + + Returns + ------- + pattern: Optional[FusionPattern] + The matched pattern. Returns None if such pattern is not found. + """ + return _ffi_api.GetPattern(name) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py new file mode 100644 index 000000000000..e27b91b3eaa6 --- /dev/null +++ b/python/tvm/relax/backend/patterns.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Common patterns used in BYOC""" + +from typing import Dict, Mapping, Tuple, Union + +from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard + + +def _with_bias_activation_pattern( + out: DFPattern, + annotations: Dict[str, DFPattern], + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + if with_bias: + annotations["bias"] = bias = wildcard() + out = is_op("relax.add")(out, bias) + + if activation: + out = is_op(activation)(out) + + return out, annotations + + +def make_fused_bias_activation_pattern( + op_name: str, + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + A simple utility to create patterns for an operation fused with bias addition and activation. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused operation + + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + lhs = wildcard() + rhs = wildcard() + out = is_op(op_name)(lhs, rhs) + annotations = {"lhs": lhs, "rhs": rhs, "root": out} + + return _with_bias_activation_pattern(out, annotations, with_bias, activation) + + +def make_residual_block_pattern( + node_output: Union[DFPattern, Tuple[DFPattern, Mapping[str, DFPattern]]], + binary_op="relax.add", + activation=None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + Create pattern for residual block. + + Parameters + ---------- + node_output: Union[DFPattern, Tuple[DFPattern, Mapping[str, DFPattern]]] + The output of previous node. + + binary_op: str + The op used to combine previous node output and residual input. + + activation: str + The activation function of this residual block. It should be a name of + activation Relax op, such as "relax.nn.relu". + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a matrix multiplication. + + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + if isinstance(node_output, tuple): + node_output, arg_patterns = node_output + else: + arg_patterns = {} + + residual_input = wildcard() + op = is_op(binary_op) + output = op(node_output, residual_input) | op(residual_input, node_output) + + if activation is not None: + output = is_op(activation)(output) + + return output, {**arg_patterns, "residual": residual_input} + + +def make_matmul_pattern( + with_bias: bool = False, + activation: str = None, + transposed_rhs: bool = False, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + Create pattern for matrix multiplication. 
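+
+    When transposed_rhs is True, the right operand is matched through a
+    relax.permute_dims op preceding the relax.matmul.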
+ + Parameters + ---------- + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + transposed_rhs: bool + Whether the right hand side of multiplication is transposed. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a matrix multiplication. + + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + lhs = wildcard() + rhs = wildcard() + annotations = {"lhs": lhs, "rhs": rhs} + + if transposed_rhs: + rhs = is_op("relax.permute_dims")(rhs) + + out = is_op("relax.matmul")(lhs, rhs) + annotations["root"] = out + + return _with_bias_activation_pattern(out, annotations, with_bias, activation) + + +def make_attention_pattern(with_bias: bool = False): + """ + Create pattern for fused multi head attention. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused multi head attention. + + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + query = wildcard() + key = wildcard() + value = wildcard() + annotations = {"query": query, "key": key, "value": value} + if with_bias: + bias = wildcard() + annotations["bias"] = bias + out = is_op("relax.nn.attention_bias")(query, key, value, bias) + else: + out = is_op("relax.nn.attention")(query, key, value) + + return out, annotations diff --git a/python/tvm/relax/backend_tir/__init__.py b/python/tvm/relax/backend_tir/__init__.py new file mode 100644 index 000000000000..eeb8fe438f6e --- /dev/null +++ b/python/tvm/relax/backend_tir/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax backends, tir based""" + +from . import contrib +from .pattern import get_tir_pattern diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py b/python/tvm/relax/backend_tir/contrib/__init__.py new file mode 100644 index 000000000000..9274f22374b9 --- /dev/null +++ b/python/tvm/relax/backend_tir/contrib/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""External backend codegen modules for Relax, tir based.""" + +from .cutlass import cutlass_fcodegen diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py b/python/tvm/relax/backend_tir/contrib/cutlass.py new file mode 100644 index 000000000000..0dbe31c468ad --- /dev/null +++ b/python/tvm/relax/backend_tir/contrib/cutlass.py @@ -0,0 +1,720 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring +"""codegen for cutlass""" +import operator +from functools import reduce +from typing import List, Dict, Any + +from tvm.contrib.cutlass.build import _get_cutlass_path, _get_cutlass_compile_options +from tvm.contrib.nvcc import get_target_compute_version +from tvm.contrib.cutlass.library import LayoutType, ConvKind +from tvm.contrib.cutlass.gen_tensor_op import instantiate_template +from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler +from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler +from ..pattern import ( + MatchResult, + matmul_rrr_fp16, + bias_row_2d_fp16, + bias_row_1d_fp16, + batch_bias_row_2d_fp16, + batch_bias_row_1d_fp16, + relu_fp16, + erf_3d_fp32, + batch_matmul_rrr_2d_fp16, + batch_matmul_rrr_3d_fp16, + conv2d_nhwc_fp16, + padding_2d_nhwc_fp16, + copy_4d_fp16, + bias_add_nhwc_2d_fp16, + bias_add_nhwc_1d_fp16, + elem_add_4d_fp16, + elem_mul_3d_fp16, + scalar_add_3d_fp16, + scalar_mul_3d_fp16, + cast_3d_fp16, + cast_3d_fp32, +) + +#### helper functions #### +# list representing the anchor ops +# in the future more layouts/dtypes can be supported +MATMUL_LIST = [matmul_rrr_fp16] +MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16] +BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16] +BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16] +CONV2D_LIST = [conv2d_nhwc_fp16] + +# attributes for anchor ops used in code generation +OP_PATTERN_ATTR_LIST = { + matmul_rrr_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + batch_matmul_rrr_2d_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + batch_matmul_rrr_3d_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float16", + }, + conv2d_nhwc_fp16: { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": 
"float16", + # in the future we can add layout here + }, +} + + +def _get_cutlass_code(attr): + pattern = attr["op_type"] + if pattern.startswith("cutlass.matmul"): + return cutlass_codegen_gemm(attr) + elif pattern.startswith("cutlass.conv2d"): + return cutlass_codegen_conv2d(attr) + else: + raise ValueError("op not supported") + + +def _final_code(code, headers, func_args): + res = "" + res += "#define DMLC_USE_LOGGING_LIBRARY \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + res += "#include \n" + + for header in headers: + res += "#include <" + header + ">\n" + res += "namespace {\n" + res += "using namespace tvm;\n" + res += "using namespace tvm::runtime;\n" + res += "void _cutlass_kernel(" + for arg in func_args: + res += "NDArray " + arg + ", " + res += "NDArray out0) {" + res += code + res += "}\n" + res += "} // namespace\n" + res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n" + return res + + +#### cutlass patterns #### +def matmul_bias_relu(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = matmul_bias(match_results[:2], attr, get_code=False) + if attr is None or match_results[2].pattern != relu_fp16: + return None + m_bias, n_bias = match_results[1].symbol_values + m_relu, n_relu = match_results[2].symbol_values + A_bias, B_bias, C_bias = match_results[1].matched_buffers + A_relu, B_relu = match_results[2].matched_buffers + if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu: + attr["op_type"] = "cutlass.matmul_bias_relu" + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def matmul_bias(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + attr = matmul(match_results[:1], attr, get_code=False) + if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST: + return None + m_matmul, n_matmul, k_matmul = match_results[0].symbol_values + m_bias, n_bias = match_results[1].symbol_values + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + A_bias, B_bias, C_bias = match_results[1].matched_buffers + if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: + attr["op_type"] = "cutlass.matmul_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +def matmul(match_results, attr, get_code=True): + if len(match_results) < 1: + return None + if match_results[0].pattern in MATMUL_LIST: + # matmul + attr["op_type"] = "cutlass.matmul" + return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr + return None + + +def batch_matmul_bias_gelu(match_results, attr, get_code=True): + if len(match_results) < 9: + return None + attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias + if ( + attr is None + or match_results[2].pattern != scalar_mul_3d_fp16 + or match_results[3].pattern != cast_3d_fp32 + or match_results[4].pattern != erf_3d_fp32 + or match_results[5].pattern != cast_3d_fp16 + or match_results[6].pattern != scalar_mul_3d_fp16 + or match_results[7].pattern != scalar_add_3d_fp16 + or match_results[8].pattern != elem_mul_3d_fp16 + ): + return None + + def shape_match_3d(shape1, shape2): + if len(shape1) < 3 or len(shape2) < 3: + return False + return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] == shape2[2] + + for i in 
range(1, 8): + if not shape_match_3d(match_results[i].symbol_values, match_results[i + 1].symbol_values): + return None + + if not ( + match_results[1].matched_buffers[-1] == match_results[2].matched_buffers[0] + and match_results[2].matched_buffers[-1] == match_results[3].matched_buffers[0] + and match_results[3].matched_buffers[-1] == match_results[4].matched_buffers[0] + and match_results[4].matched_buffers[-1] == match_results[5].matched_buffers[0] + and match_results[5].matched_buffers[-1] == match_results[6].matched_buffers[0] + and match_results[6].matched_buffers[-1] == match_results[7].matched_buffers[0] + and match_results[1].matched_buffers[-1] == match_results[8].matched_buffers[0] + and match_results[7].matched_buffers[-1] == match_results[8].matched_buffers[1] + ): + return None + + if ( + abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5 + or abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5 + or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5 + ): + return None + + attr["op_type"] = "cutlass.matmul_bias_gelu" + return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else attr + + +def batch_matmul_bias_residual_mul(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = batch_matmul_bias(match_results[:2], attr, get_code=False) # batch_matmul, batch_bias + if attr is None or match_results[2].pattern != elem_mul_3d_fp16: + return None + ( + b_bias, + m_bias, + n_bias, + ) = match_results[1].symbol_values + ( + b_mul, + m_mul, + n_mul, + ) = match_results[2].symbol_values + A_bias, B_bias, C_bias = match_results[1].matched_buffers + A_mul, B_mul, C_mul = match_results[2].matched_buffers + if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias == A_mul: + attr["op_type"] = "cutlass.matmul_bias_residual_multiply" + attr["residual_arg_idx"] = 3 + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def batch_matmul_bias(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + attr = batch_matmul(match_results[:1], attr, get_code=False) + if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST: + return None + ( + b_matmul, + m_matmul, + n_matmul, + k_matmul, + ) = match_results[0].symbol_values + ( + b_bias, + m_bias, + n_bias, + ) = match_results[1].symbol_values + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + A_bias, B_bias, C_bias = match_results[1].matched_buffers + if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias: + attr["op_type"] = "cutlass.matmul_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +def batch_matmul(match_results, attr, get_code=True): + if len(match_results) < 1: + return None + if match_results[0].pattern in BATCH_MATMUL_LIST: + attr["op_type"] = "cutlass.matmul" + return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr + return None + + +def conv2d_bias_residual_add(match_results, attr, get_code=True): + if len(match_results) < 4: + return None + attr = conv2d_bias(match_results[:3], attr, get_code=False) + if attr is None or match_results[3].pattern != elem_add_4d_fp16: + return None + N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values + in1_bias, in2_bias, out_bias = match_results[2].matched_buffers + N_add, H_add, W_add, C_add = match_results[3].symbol_values + in1_add, 
in2_add, out_add = match_results[3].matched_buffers + if ( + N_bias == N_add + and H_bias == H_add + and W_bias == W_add + and C_bias == C_add + and out_bias in [in1_add, in2_add] + ): + attr["op_type"] = "cutlass.conv2d_bias_residual_add" + attr["residual_arg_idx"] = 3 + attr["args"].append(in2_add if out_bias == in1_add else in1_add) + return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code else attr + return None + + +def conv2d_bias(match_results, attr, get_code=True): + if len(match_results) < 3: + return None + attr = conv2d(match_results[:2], attr, get_code=False) + if attr is None or ( + match_results[2].pattern not in [bias_add_nhwc_2d_fp16, bias_add_nhwc_1d_fp16] + ): + return None + (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) = match_results[ + 1 + ].symbol_values[:7] + A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers + N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values + A_bias, B_bias, out_bias = match_results[2].matched_buffers + if ( + N_bias == N_conv + and H_bias == H_conv + and W_bias == W_conv + and C_bias == O_conv + and out_conv == A_bias + ): + attr["op_type"] = "cutlass.conv2d_bias" + attr["bias_arg_idx"] = 2 + attr["args"].append(B_bias) + return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr + return None + + +def conv2d(match_results, attr, get_code=True): + if len(match_results) < 2: + return None + if ( + match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16] + and match_results[1].pattern == conv2d_nhwc_fp16 + ): + if match_results[0].pattern == padding_2d_nhwc_fp16: + ( + N_pad, + H_pad, + W_pad, + C_pad, + pH_pad, + pW_pad, + lH_pad, + lW_pad, + rH_pad, + rW_pad, + ) = match_results[0].symbol_values + else: + ( + N_pad, + H_pad, + W_pad, + C_pad, + ) = match_results[0].symbol_values + pH_pad = rH_pad = H_pad + pW_pad = rW_pad = W_pad + lH_pad = lW_pad = 0 + ( + N_conv, + pH_conv, + pW_conv, + H_conv, + W_conv, + C_conv, + O_conv, + KH_conv, + KW_conv, + stride_h_conv, + stride_w_conv, + dilation_h_conv, + dilation_w_conv, + ) = match_results[1].symbol_values + A, A_pad = match_results[0].matched_buffers + A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers + if ( + N_pad == N_conv + and pH_pad == pH_conv + and pW_pad == pW_conv + and C_pad == C_conv + and A_pad == A_pad_conv + ): + if ( + lH_pad == pH_pad - rH_pad + and lW_pad == pW_pad - rW_pad + and lH_pad + H_pad == rH_pad + and lW_pad + W_pad == rW_pad + ): + padding = (lH_pad, lW_pad) + strides = (stride_h_conv, stride_w_conv) + dilation = (dilation_h_conv, dilation_w_conv) + attr["padding"] = padding + attr["strides"] = strides + attr["dilation"] = dilation + attr["op_type"] = "cutlass.conv2d" + return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr + return None + + +### cutlass codegen functions ### +def compile_options(target, threads=-1, use_fast_math=False): + compute_version = int("".join(get_target_compute_version(target).split("."))) + kwargs = _get_cutlass_compile_options(compute_version, threads, use_fast_math) + kwargs["options"].remove("-c") + return kwargs + + +def cutlass_fcodegen(sm=80, bin_dir="./bin"): + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir) + + def cutlass_codegen_with_match_results(match_results: List[MatchResult]): + """generate cutlass code with match results""" + nonlocal gemm_profiler + nonlocal conv2d_profiler + + assert len(match_results) > 0 + + # add 
shape into attr + if match_results[0].pattern in MATMUL_LIST: + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + attr: Dict[Any, Any] = OP_PATTERN_ATTR_LIST[match_results[0].pattern] + attr["args"] = [A_matmul, B_matmul] + attr["arg0_shape"] = A_matmul.shape + attr["arg1_shape"] = B_matmul.shape + attr["ret_shape"] = C_matmul.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + elif match_results[0].pattern in BATCH_MATMUL_LIST: + A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers + attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern] + attr["args"] = [A_matmul, B_matmul] + attr["arg0_shape"] = A_matmul.shape + attr["arg1_shape"] = B_matmul.shape + attr["ret_shape"] = C_matmul.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + elif len(match_results) >= 1 and match_results[1].pattern in CONV2D_LIST: + A_input = match_results[0].matched_buffers[0] + A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers + attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern] + attr["args"] = [A_input, B_conv2d] + attr["arg0_shape"] = A_input.shape + attr["arg1_shape"] = B_conv2d.shape + attr["ret_shape"] = C_conv2d.shape + attr["lhs_arg_idx"] = 0 + attr["rhs_arg_idx"] = 1 + else: + return ["", 0] + + # add profiler into attr + attr["gemm_profiler"] = gemm_profiler + attr["conv2d_profiler"] = conv2d_profiler + + cutlass_patterns = [ + # 9 + batch_matmul_bias_gelu, + # 4 + conv2d_bias_residual_add, + # 3 + batch_matmul_bias_residual_mul, + matmul_bias_relu, + conv2d_bias, + # 2 + matmul_bias, + batch_matmul_bias, + conv2d, + # 1 + matmul, + batch_matmul, + ] + for pattern in cutlass_patterns: + res = pattern(match_results, attr) + if res is not None: + return res + + return ["", 0] + + return cutlass_codegen_with_match_results + + +def cutlass_codegen_gemm(attrs): + """cutlass codegen for gemm""" + gemm_profiler = attrs["gemm_profiler"] + op_type = attrs["op_type"] + lhs_shape = attrs["arg0_shape"] + rhs_shape = attrs["arg1_shape"] + MM = lhs_shape[-2] + KK = lhs_shape[-1] + if "transposed" in op_type: + NN = rhs_shape[-2] + ldb = "K" + layout_b = LayoutType.ColumnMajor + else: + NN = rhs_shape[-1] + ldb = "N" + layout_b = LayoutType.RowMajor + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + if lhs_batches == 1 and rhs_batches == 1: + # Regular matmul + is_batched = False + batch_attrs = {} + else: + is_batched = True + batch_attrs = { + # If both lhs_batches and rhs_batches are greater than 1, + # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul. 
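+            # A batch dimension of 1 on either operand is broadcast: its batch
+            # stride is set to 0 below so the same matrix is reused for every batch.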
+ "batch": lhs_batches if rhs_batches == 1 else rhs_batches, + "batch_stride_A": 0 if lhs_batches == 1 else MM * KK, + "batch_stride_B": 0 if rhs_batches == 1 else KK * NN, + "batch_stride_C": MM * NN, + } + op_name, op_def, _ = gemm_profiler.profile( + op_type, + MM, + NN, + KK, + attrs["ret_dtype"], + attrs["arg0_dtype"], + attrs["arg1_dtype"], + False, + batched=is_batched, + find_first_valid=False, + use_multiprocessing=True, + layout_b=layout_b, + ) + attrs["cutlass_op_name"] = op_name + attrs["cutlass_op_def"] = op_def + attrs["lda"] = "K" + attrs["ldb"] = ldb + attrs["ldc"] = "N" + attrs.update(batch_attrs) + del attrs["gemm_profiler"] + del attrs["conv2d_profiler"] + + nargs = 2 + if "bias_arg_idx" in attrs: + nargs += 1 + if "residual_arg_idx" in attrs: + nargs += 1 + func_args = ["inp" + str(i) for i in range(nargs)] + + # A temporary solution to handle batch matmul residual cases + # TODO(@bohan): remove this after initialize_template supports bmm residual + if op_type in [ + "cutlass.matmul_bias_residual_multiply", + ]: + + def _convert_dtype_str(dtype): + if isinstance(dtype, list): + arr = [] + for t in dtype: + arr.append(_convert_dtype_str(t)) + return arr + elif isinstance(dtype, str): + if dtype == "float16": + return "cutlass::half_t" + elif dtype == "float32": + return "float" + raise ValueError("dtype not supported") + + typea, typeb, typec = _convert_dtype_str( + [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]] + ) + + text = f""" +#define CUTLASS_ENABLE_CUBLAS 1 +#define CUTLASS_NAMESPACE cutlass +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1 +#define NDEBUG +#include +#include +#include +#include +#include +#include +#include +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include +#include +#include +#include +#define DMLC_USE_LOGGING_LIBRARY +#include +#include +#include +namespace {{ +using namespace tvm; +using namespace tvm::runtime; +void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{ + // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, N], C: [Batch, M, N] + CHECK_EQ(A->ndim, 3); + int bdim = B->ndim; + int bias_dim = Bias->ndim; + CHECK_EQ(C->ndim, 3); + CHECK_EQ(A->shape[2], B->shape[bdim - 2]); + CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]); + CHECK_EQ(D->ndim, 3); + CHECK_EQ(D->shape[0], A->shape[0]); + CHECK_EQ(D->shape[1], A->shape[1]); + CHECK_EQ(D->shape[2], B->shape[bdim - 1]); + CHECK_EQ(C->shape[0], A->shape[0]); + CHECK_EQ(C->shape[1], A->shape[1]); + CHECK_EQ(C->shape[2], B->shape[bdim - 1]); + int64_t M = A->shape[0] * A->shape[1]; + int64_t N = B->shape[bdim - 1]; + int64_t K = A->shape[2]; + int64_t input_a_batch_stride = M * K; + int64_t input_a_stride = K; + int64_t input_a_offset = 0; // default to 0 + int64_t input_b_batch_stride = K * N; + int64_t input_b_stride = N; + int64_t input_b_offset = 0; // default to 0 + int64_t output_stride = N; + int64_t output_offset = 0; + int64_t a_size = 1; + a_size *= A->shape[0]; + a_size *= A->shape[1]; + a_size *= A->shape[2]; + + int64_t b_size = 1; + b_size *= B->shape[bias_dim - 2]; + b_size *= B->shape[bias_dim - 1]; + + int64_t c_size = 1; + c_size *= C->shape[0]; + c_size *= C->shape[1]; + c_size *= C->shape[2]; + + // Define the GEMM operation + {op_def} + using kernel = Operation_{op_name}; + using ElementComputeEpilogue = typename kernel::ElementAccumulator; + typename kernel::Arguments 
arguments({{ + cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode + {{M, N, K}}, // GemmCoord problem_size + 1, // int batch_count + {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename EpilogueOutputOp::Params epilogue + ({typea}*)(A->data) + input_a_offset, // void const * ptr_A + ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B + ({typec}*)(D->data), // void const * ptr_C1 + ({typec}*)(C->data) + output_offset, // void * ptr_D + ({typea}*)(Bias->data), // void * ptr_Vector + nullptr, // void * ptr_Tensor + input_a_batch_stride, // int64_t batch_stride_A + input_b_batch_stride, // int64_t batch_stride_B + 0, // int64_t batch_stride_C1 + 0, // int64_t batch_stride_D + 0, // int64_t batch_stride_Vector + 0, // int64_t batch_stride_Tensor + input_a_stride, // typename LayoutA::Stride::Index lda + input_b_stride, // typename LayoutB::Stride::Index ldb + N, // typename LayoutC::Stride::Index ldc1 + output_stride, // typename LayoutC::Stride::Index ldd + 0, // typename LayoutC::Stride::Index ldr + 0, // typename LayoutC::Stride::Index ldt + }}); + kernel gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + cutlass::Status status = gemm_op.can_implement(arguments); + CHECK(status == cutlass::Status::kSuccess); + status = gemm_op.initialize(arguments, workspace.get()); + CHECK(status == cutlass::Status::kSuccess); + status = gemm_op(); + CHECK(status == cutlass::Status::kSuccess); + return; +}} +}} // namespace +TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM); + """ + return text + + code = instantiate_template(op_type, attrs, func_args) + return _final_code(code.code, code.headers, func_args) + + +def cutlass_codegen_conv2d(attrs): + """cutlass codegen for conv2d""" + # cutlass backend only supports nhwc for now + conv2d_profiler = attrs["conv2d_profiler"] + op_type = attrs["op_type"] + conv_kind = ConvKind.Fprop + op_name, op_def, _ = conv2d_profiler.profile( + op_type=attrs["op_type"], + d_shape=attrs["arg0_shape"], + w_shape=attrs["arg1_shape"], + padding=attrs["padding"], + stride=attrs["strides"], + dilation=attrs["dilation"], + out_dtype=attrs["ret_dtype"], + data_dtype=attrs["arg0_dtype"], + weight_dtype=attrs["arg1_dtype"], + use_3xtf32=False, + conv_kind=conv_kind, + split_k_slices=[1], + profile_all_alignments=True, + find_first_valid=False, + use_multiprocessing=True, + ) + attrs["cutlass_op_def"] = op_def + attrs["cutlass_op_name"] = op_name + del attrs["gemm_profiler"] + del attrs["conv2d_profiler"] + + nargs = 2 + if "bias_arg_idx" in attrs: + nargs += 1 + if "residual_arg_idx" in attrs: + nargs += 1 + func_args = ["inp" + str(i) for i in range(nargs)] + code = instantiate_template(op_type, attrs, func_args) + return _final_code(code.code, code.headers, func_args) diff --git a/python/tvm/relax/backend_tir/pattern.py b/python/tvm/relax/backend_tir/pattern.py new file mode 100644 index 000000000000..10f7a3b1628d --- /dev/null +++ b/python/tvm/relax/backend_tir/pattern.py @@ -0,0 +1,576 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,chained-comparison +"""TIR Patterns""" +from typing import List + +import tvm +from tvm.runtime import Object +import tvm._ffi + +from tvm.script import tir as T + + +@tvm._ffi.register_object("relax.MatchResult") +class MatchResult(Object): + """The match result of a TIR pattern.""" + + def __init__(self, pattern, symbol_values, matched_buffers): + self.__init_handle_by_constructor__( + tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers + ) + + +@T.prim_func +def matmul_rrr_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(M, N, K): + with T.block("matmul"): + i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1]) + T.writes(matmul[i0_1, i1_1]) + with T.init(): + matmul[i0_1, i1_1] = T.float16(0) + matmul[i0_1, i1_1] = ( + matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] * rxplaceholder_1[k, i1_1] + ) + + +@T.prim_func +def bias_row_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") + T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[T.int64(0), ax1] + + +@T.prim_func +def bias_row_1d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") + T_add = T.match_buffer(var_T_add, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1] + + +@T.prim_func +def batch_bias_row_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = 
T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16") + T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(batch, M, N): + with T.block("T_add"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0), ax2]) + T.writes(T_add[ax0, ax1, ax2]) + T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[T.int64(0), ax2] + + +@T.prim_func +def batch_bias_row_1d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_T_add: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16") + T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2 in T.grid(batch, M, N): + with T.block("T_add"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2]) + T.writes(T_add[ax0, ax1, ax2]) + T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[ax2] + + +@T.prim_func +def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64, N: T.int64) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16") + compute = T.match_buffer(var_compute, [M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1 in T.grid(M, N): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float16(0)) + + +@T.prim_func +def batch_matmul_rrr_2d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2, i3 in T.grid(batch, M, N, K): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float16(0) + matmul[i0_1, i1_1, i2_1] = ( + matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[k, i2_1] + ) + + +@T.prim_func +def batch_matmul_rrr_3d_fp16( + var_rxplaceholder: T.handle, + var_rxplaceholder_1: T.handle, + var_matmul: T.handle, + batch: T.int64, + M: T.int64, + N: T.int64, + K: T.int64, +) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N], dtype="float16") + matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16") + # body + # with T.block("root") + for i0, i1, i2, i3 in T.grid(batch, M, N, K): + with T.block("matmul"): + 
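+            # 3-D variant: unlike batch_matmul_rrr_2d_fp16 above, the second
+            # operand also carries the batch dimension.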
i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float16(0) + matmul[i0_1, i1_1, i2_1] = ( + matmul[i0_1, i1_1, i2_1] + + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1] + ) + + +@T.prim_func +def copy_4d_fp16( + A_handle: T.handle, + B_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") + # body + # with T.block("root") + for n, h, w, c in T.grid(N, H, W, C): + with T.block("copy"): + vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c]) + T.reads(A[vn, vh, vw, vc]) + T.writes(B[vn, vh, vw, vc]) + B[vn, vh, vw, vc] = A[vn, vh, vw, vc] + + +@T.prim_func +def padding_2d_nhwc_fp16( + A_handle: T.handle, + B_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, + pH: T.int64, + pW: T.int64, + lH: T.int64, + lW: T.int64, + rH: T.int64, + rW: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16") + # body + # with T.block("root") + for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C): + with T.block("copy"): + v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3]) + T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7]) + T.writes(B[v_4, v_5, v_6, v_7]) + B[v_4, v_5, v_6, v_7] = T.if_then_else( + lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW, + A[v_4, v_5 - lH, v_6 - lW, v_7], + T.float16(0), + dtype="float16", + ) + + +@T.prim_func +def conv2d_nhwc_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + pH: T.int64, + pW: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, + O: T.int64, + KH: T.int64, + KW: T.int64, + StrideH: T.int64, + StrideW: T.int64, + DilateH: T.int64, + DilateW: T.int64, +) -> None: + A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16") + B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16") + # body + # with T.block("root") + for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C): + with T.block("conv"): + v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap( + "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6] + ) + T.reads( + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13], + B[v_10, v_11, v_12, v_13], + ) + T.writes(out[v_7, v_8, v_9, v_10]) + with T.init(): + out[v_7, v_8, v_9, v_10] = T.float16(0) + out[v_7, v_8, v_9, v_10] = ( + out[v_7, v_8, v_9, v_10] + + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13] + * B[v_10, v_11, v_12, v_13] + ) + + +@T.prim_func +def bias_add_nhwc_2d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), T.int64(0), v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0), 
T.int64(0), v_ax3] + ) + + +@T.prim_func +def bias_add_nhwc_1d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), T.int64(0), v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0), T.int64(0), v_ax3] + ) + + +@T.prim_func +def elem_add_2d_fp16( + in0_handle: T.handle, + in1_handle: T.handle, + out_handle: T.handle, + N: T.int64, + M: T.int64, +): + in0 = T.match_buffer(in0_handle, [N, M], dtype="float16") + in1 = T.match_buffer(in1_handle, [N, M], dtype="float16") + out = T.match_buffer(out_handle, [N, M], dtype="float16") + for ax0, ax1 in T.grid(N, M): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1]) + T.writes(out[v_ax0, v_ax1]) + out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1] + + +@T.prim_func +def elem_add_3d_fp16( + in0_handle: T.handle, + in1_handle: T.handle, + out_handle: T.handle, + B: T.int64, + N: T.int64, + M: T.int64, +): + in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16") + in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16") + out = T.match_buffer(out_handle, [B, N, M], dtype="float16") + for ax0, ax1, ax2 in T.grid(B, N, M): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def elem_add_4d_fp16( + A_handle: T.handle, + B_handle: T.handle, + out_handle: T.handle, + N: T.int64, + H: T.int64, + W: T.int64, + C: T.int64, +): + A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16") + B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16") + out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16") + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3]) + out[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3] + ) + + +@T.prim_func +def scalar_mul_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, + scalar: T.float16, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_mul"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar + + +@T.prim_func +def erf_3d_fp32( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") + out = T.match_buffer(out_handle, [D1, D2, D3], 
dtype="float32") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_erf"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2]) + + +@T.prim_func +def scalar_add_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, + scalar: T.float16, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def elem_mul_3d_fp16( + inp0_handle: T.handle, + inp1_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_mul"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0, v_ax1, v_ax2] + + +@T.prim_func +def cast_3d_fp16( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1, v_ax2]) + + +@T.prim_func +def cast_3d_fp32( + inp0_handle: T.handle, + out_handle: T.handle, + D1: T.int64, + D2: T.int64, + D3: T.int64, +): + inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16") + out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32") + for ax0, ax1, ax2 in T.grid(D1, D2, D3): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(inp0[v_ax0, v_ax1, v_ax2]) + T.writes(out[v_ax0, v_ax1, v_ax2]) + out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1, v_ax2]) + + +def get_tir_pattern() -> List[tvm.tir.PrimFunc]: + """Get the tir patterns for backend dispatch.""" + return [ + matmul_rrr_fp16, + bias_row_2d_fp16, + bias_row_1d_fp16, + batch_bias_row_2d_fp16, + batch_bias_row_1d_fp16, + relu_fp16, + erf_3d_fp32, + batch_matmul_rrr_2d_fp16, + batch_matmul_rrr_3d_fp16, + copy_4d_fp16, + padding_2d_nhwc_fp16, + conv2d_nhwc_fp16, + bias_add_nhwc_2d_fp16, + bias_add_nhwc_1d_fp16, + elem_add_2d_fp16, + elem_add_3d_fp16, + elem_add_4d_fp16, + elem_mul_3d_fp16, + scalar_add_3d_fp16, + scalar_mul_3d_fp16, + cast_3d_fp16, + cast_3d_fp32, + ] diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py new file mode 100644 index 000000000000..a9f6d878ad0d --- /dev/null +++ b/python/tvm/relax/binding_rewrite.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of add/remove/replace bindings in Relax.""" + +from typing import Optional + +import tvm +import tvm._ffi +from tvm.runtime import Object +from . import Binding, DataflowBlock, Expr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.DataflowBlockRewrite") +class DataflowBlockRewrite(Object): + """ + A binding/statement-level dataflow block rewriter. + + Notes + ----- + Due to the immutable and copy-on-write nature of TVM AST nodes, the rewriting is not done in + place. Instead, a new DataflowBlock is created and returned with mutated_dfb. Similarly, its new + root Function is created and returned by mutated_root_fn. To apply this change for an IRModule, + use mutate_irmodule which rewrites the old function that registered in the constructor. + """ + + def __init__(self, dfb: DataflowBlock, root_fn: Function): + """ + Construct a rewriter with the DataflowBlock to rewrite and its root function. + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to rewrite. + root_fn : Function + The root function of the DataflowBlock. + """ + self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None + self.__init_handle_by_constructor__( + _ffi_api.DataflowBlockRewrite, dfb, root_fn # type: ignore + ) + + def replace_all_uses(self, old_var: Var, new_var: Var) -> None: + """ + Replace all uses of old_var with new_var. + + Parameters + ---------- + old_var : Var + The old variable to replace. + new_var : Var + The new variable to replace with. + """ + _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) # type: ignore + + def add_binding(self, binding: Binding) -> None: + return _ffi_api.dfb_rewrite_add_binding(self, binding) # type: ignore + + def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None: + """ + Add a new statement to the DataflowBlock with an automatically generated variable name. + + Parameters + ---------- + expr : Expr + The expression to add. + name : Optional[str], optional + Variable name, by default None + is_dfvar : bool, optional + The variable type, by default False + + Notes + ----- + If the variable name is not given, it will be automatically generated in a form of + "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is True, otherwise + it will be Var. Being Var means the variables are output variables of the DataflowBlock. + While being DataflowVar means the variables are internal variables of the DataflowBlock. + """ + _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) # type: ignore + + def remove_unused(self, var: Var, allow_undef=False) -> None: + """ + Remove a statement by its variable definition if and only if it is unused. + + Parameters + ---------- + var : Var + The unused variable definition. 
+ allow_undef : bool, optional + Whether to allow var being undefined variable, by default False + + Raises + ------ + TVMError if the variable is used or undefined (allow_undef=False). + """ + _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type: ignore + + def remove_all_unused(self) -> None: + """ + Remove all unused variables. + + Notes + ----- + This could remove unused variables in other DataflowBlocks as well. + """ + _ffi_api.dfb_rewrite_remove_all_unused(self) # type: ignore + + def mutated_dfb(self) -> DataflowBlock: + """ + Returns the mutated DataflowBlock. + """ + return self.dfb + + def mutated_root_fn(self) -> Function: + """ + Returns the mutated root function. + """ + ret = self.root_fn + if self.func_name: + ret.__name__ = self.func_name + return ret + + def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule: + """ + Return an updated IRModule by replacing the old function with the mutated root function. + + Parameters + ---------- + irmodule : tvm.IRModule + The base IRModule to update. + + Returns + ------- + tvm.IRModule + The updated IRModule. + """ + ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) # type: ignore + if hasattr(irmodule, "__name__"): + ret.__name__ = irmodule.__name__ + return ret diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py new file mode 100644 index 000000000000..c2a32e563fed --- /dev/null +++ b/python/tvm/relax/block_builder.py @@ -0,0 +1,658 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of constructing Relax AST.""" + +from typing import Dict, List, Optional, Union, Any, Callable +from tvm.ir.module import IRModule +from tvm.runtime import Object +from tvm import relax as rx, tir +import tvm +from .expr import ( + Expr, + Var, + GlobalVar, + BindingBlock, + Tuple, + BaseFunc, + Binding, +) +from .struct_info import StructInfo +from .op.base import call_tir +from . 
import _ffi_api +from .utils import gen_call_tir_inputs + + +class FunctionScope(object): + """Auxiliary scope for function""" + + def __init__(self, block_builder, name, params, attrs): + self._bb = block_builder + self._name = name + self._params = params + self._attrs = attrs + + def __enter__(self): + self._bb._enter_function_scope(self._name, self._params, self._attrs) + + def __exit__(self, exc_type, exc_val, exc_tb): + # __exit__ should properly handle the case where the with block exits with an exception + # when handling error case in exit, always check if there is already an exception + # been thrown in the with block + self._bb._exit_function_scope(exc_type, exc_val, exc_tb) + + +class DataflowScope(object): + """Auxiliary scope for Dataflow block""" + + def __init__(self, block_builder): + self._bb = block_builder + + def __enter__(self): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_binding_block() + + +class TestingScope(object): + """Auxiliary scope for testing purposes""" + + def __init__(self, block_builder, def_vars): + self._bb = block_builder + shape_vars = [] + for var in def_vars: + if isinstance(var, tvm.tir.Var): + shape_vars.append(var) + else: + raise ValueError("def_vars only can take tir.Var") + # setup a dummy var so shape is in scope. + sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars)) + self._scope_params = [sparam] + + def __enter__(self): + self._bb.begin_scope(self._scope_params) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + self._bb._end_block() + self._bb.end_scope() + + +@tvm._ffi.register_object("relax.BlockBuilder") +class BlockBuilder(Object): + """A builder to build Relax IR for testing and dev. + + Examples + -------- + .. code-block:: python + + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16") + bb = rx.BlockBuilder() + with bb.function([x, y], "func"): + with bb.dataflow() as df: + lv0 = bb.emit(rx.add(x, y)) + lv1 = bb.emit(rx.multiply(lv0, y)) + gv0 = bb.emit_output(lv1) + bb.emit_func_output(gv0) + mod = bb.get() + + BlockBuilder can also be used to construct neural networks with nn.Module API + + .. 
code-block:: python + + from tvm.relax.testing import nn + + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + bb = rx.BlockBuilder() + + with bb.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + mod = bb.get() + """ + + _current = None + + @staticmethod + def current(): + """Returns the current BlockBuilder.""" + return BlockBuilder._current + + def __init__(self, mod: IRModule = None): + self._blocks: List[BindingBlock] = [] + # a boolean flag that tracks if emit_func_output has been called + self._is_emit_func_output_called = False + self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore + + def _begin_dataflow_block(self) -> None: + _ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore + + def _begin_binding_block(self) -> None: + _ffi_api.BlockBuilderBeginBindingBlock(self) # type: ignore + + def _end_block(self) -> BindingBlock: + return _ffi_api.BlockBuilderEndBlock(self) # type: ignore + + def _enter_function_scope(self, name, params, attrs): + if BlockBuilder.current() is not None: + raise RuntimeError("BlockBuilder does not allow nested functions.") + BlockBuilder._current = self + self._func_name = name + self._func_params = params + self._func_attrs = attrs + self.begin_scope(params) + self._begin_binding_block() + + def _exit_function_scope(self, exc_type, exc_val, exc_tb): + # record + is_emit_func_output_called = self._is_emit_func_output_called + # recover to default state + self._blocks = [] + self._is_emit_func_output_called = False + BlockBuilder._current = None + + # NOTE: we must raise after we recover the state so future + # block builder scoping functions correctly + if exc_type is None: + if not is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called in a relax function.") + + def function( + self, + name: str, + params: Optional[Union[Var, Tuple, List[Var]]] = None, + attrs: Optional[Dict[str, Object]] = None, + ) -> FunctionScope: + """Annotate a Relax function. + + Parameters + ---------- + name : str, optional + The name of the function + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function. + If params is None, it means deferring initialization of function parameters + until emit_func_output. + + attrs : Dict[str, Object], optional + The function attrs + + Returns + ------- + ret: FunctionScope + A FunctionScope for building a Relax function node. + """ + if not params: + params = None + elif isinstance(params, rx.Var): + params = [params] + elif isinstance(params, (list, tuple)): + for param in params: + if not isinstance(param, rx.Var): + raise TypeError( + "each element of function parameters must be of type tvm.relax.Var,\ + but got: {}".format( + type(param) + ) + ) + if attrs is None: + attrs = {} + return FunctionScope(self, name, params, attrs) + + def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: + """Start a scope for unit-testing purposes. + + Parameters + ---------- + def_vars: List[tir.Var] + List of symbolic variables that are marked as defined in scope. 
+ + Returns + ------- + ret: TestingScope + A TestingScope to setup builder for emit and other purposes. + """ + return TestingScope(self, def_vars) + + def dataflow(self) -> DataflowScope: + """Annotate a Relax dataflow block. + + Returns + ------- + ret: DataflowScope + A DataflowScope for building a Relax dataflow block. + """ + return DataflowScope(self) + + def emit(self, expr: Expr, name_hint: str = "") -> Var: + """Emit an expr. + This infers the shape and type of the expr, create a variable, + and bind the expr to the variable. + + Parameters + ---------- + expr : tvm.relax.Expr + The Expr to be emitted. + + name_hint : str + Name hint for the bound variable. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the input expr. + """ + return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore + + def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: + """Generate a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + Please see detailed example in emit_te + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the following keyword args are reserved: + + - 'primfunc_name_hint' for passing name hint to the PrimFunc + that gets generated. + - 'primfunc_attrs' is reserved for passing func attributes to + be added to the PrimFunc that gets created. + + + Returns + ------- + ret : tvm.relax.Call + A newly created call node + """ + + primfunc_name = kwargs.pop("primfunc_name_hint", None) + tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + if not primfunc_name: + primfunc_name = func.__name__ + gvar = self.add_func(tir_func, primfunc_name) + + return call_tir(gvar, call_args, output_sinfo, tir_vars) + + def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the call code. + + Example + ------- + + .. code-block:: python + + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A = args[0] + B = args_dict["B"] + return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + + with bb.function([x, y], "rx_func"): + out = bb.emit_te(te_func, [x], {"B": y}, msg="hello") + bb.emit_func_output(out) + + will result in TVMScript + + .. 
code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, + var_compute: T.handle) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + compute = T.match_buffer(var_compute, [128, 128], dtype="float32") + # body + # with T.block("root") + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]]) + T.writes([compute[i, j]]) + compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j] + + @R.function + def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor: + # block 0 + gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32")) + return gv + + Example + ------- + + .. code-block:: python + + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([n], "float32")) + y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32")) + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1) + + will result in TVMScript + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> None: + rxplaceholder = T.match_buffer(var_rxplaceholder, [n + T.int64(1)], + dtype="float32") + compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32") + # body + # with T.block("root") + for i0 in T.serial(0, n + T.int64(1)): + with T.block("compute"): + i = T.axis.spatial(n + T.int64(1), i0) + T.reads([rxplaceholder[i]]) + T.writes([compute[i]]) + compute[i] = rxplaceholder[i] + + @R.function + def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) + -> Tensor(None, "float32", ndim=-1): + # block 0 + gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,)) + return gv + """ + return self.emit(self.call_te(func, *args, **kwargs)) + + def match_cast(self, value: Expr, struct_info: StructInfo) -> Var: + """Emit a MatchCast. + + Parameters + ---------- + value : tvm.relax.Expr + The value of the MatchCast to be emitted. + + struct_info : StructInfo + The struct info to be matched. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that get bounds to be the casted result. + """ + return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore + + def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "") -> Var: + """Emit output for the current dataflow block or function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + name_hint : str + Name hint for the bound variable. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint) # type: ignore + + def emit_func_output( + self, + output: Union[Expr, Tuple, List[Expr]], + params: Optional[Union[Var, Tuple, List[Var]]] = None, + ) -> None: + """Emit output for the function. 
+ + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function to be built. + If params is None, it means the params have been initialized in the function with scope. + """ + if self._is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called exactly once in a relax function.") + self._is_emit_func_output_called = True + + if self._func_params is not None and params is not None: + raise RuntimeError( + "function parameters have been initialized in the function with scope." + ) + + if self._func_params is None and params is None: + raise RuntimeError("Relax function must have parameter.") + + if self._func_params is None: + self._func_params = params + + if BlockBuilder.current() is not self: + raise RuntimeError("BlockBuilder._current must be self.") + + if isinstance(output, (list, tuple)): + output = Tuple(output) + + block = self._end_block() + if len(block.bindings) > 0: + self._blocks.append(block) + seqe = self.normalize(rx.SeqExpr(self._blocks, output)) + + # do not specify ret_struct_info and let constructor deduce + # from seqe.struct_info + func = rx.Function(self._func_params, seqe) + for key, value in self._func_attrs.items(): + func = func.with_attr(key, value) + self.end_scope() + self.add_func(func, self._func_name) + + def normalize(self, expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.BlockBuilderNormalize(self, expr) # type: ignore + + def get(self) -> tvm.IRModule: + """Return the IRModule being built. + + Returns + ------- + ret : tvm.IRModule + An IRModule with Relax and TIR functions being built. + """ + return _ffi_api.BlockBuilderGetContextIRModule(self) # type: ignore + + def get_unique_name(self, name_prefix: str) -> str: + """Generate a unique name with a specified prefix. + + Parameters + ---------- + name_hint : str + The name prefix. + + Returns + ------- + ret : str + The generated name. + """ + return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix) # type: ignore + + def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + func : BaseFunc + The function to be added. + + func_name : str + The name of the function to be added. + + Returns + ------- + gvar : GlobalVar + The global var bound to the added function. + """ + return _ffi_api.BlockBuilderAddFunction(self, func, func_name) # type: ignore + + def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + gv : GlobalVar + The global var referring the function to be updated. + + updated_func : BaseFunc + The updated function. + """ + return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) # type: ignore + + def current_block_is_dataflow(self) -> bool: + """Check if the block being built is DataflowBlock or not. + + Returns + ------- + ret : bool + A boolean that indicates if the block being built is DataflowBlock or not. + """ + return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) # type: ignore + + def emit_normalized(self, binding: Binding) -> None: + """Emit an already normalized binding. 
+ + Parameters + ---------- + binding: Binding + The binding to be emitted. + """ + _ffi_api.BlockBuilderEmitNormalized(self, binding) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Lookup a var in the binding table binding_table_. + + Parameters + ---------- + var: Var + The input var. + + Returns + ------- + expr: Expr + The Expr bound to the input var. + """ + return _ffi_api.BlockBuilderLookupBinding(self, var) # type: ignore + + def begin_scope(self, params: Optional[List[Var]] = None) -> None: + """Begin a new scope, with optional parameters that + are visible within the scope. + + Parameters + ---------- + params: Optional[List[Var]] + Parameters that are visible within the scope. + + Note + ---- + This function should be called when new scope is introduced + (function, seq) to properly track the variable availability + and help the best effort deduction. + """ + + return _ffi_api.BlockBuilderBeginScope(self, params) # type: ignore + + def end_scope(self) -> None: + """End the current scope. Please see `begin_scope` for details""" + + return _ffi_api.BlockBuilderEndScope(self) # type: ignore diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py new file mode 100644 index 000000000000..6451238428c2 --- /dev/null +++ b/python/tvm/relax/dpl/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Relax Dataflow Pattern Language.""" + +from .pattern import * +from .context import * +from .rewrite import rewrite_call, rewrite_bindings diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py new file mode 100644 index 000000000000..6699e42bee63 --- /dev/null +++ b/python/tvm/relax/dpl/_ffi.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""DataFlow Pattern Language FFI bindings.""" +import tvm._ffi + +tvm._ffi._init_api("relax.dpl", __name__) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py new file mode 100644 index 000000000000..69a5e70ed0f1 --- /dev/null +++ b/python/tvm/relax/dpl/context.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Graph Matching Context Manager for Dataflow Pattern Language.""" + +from typing import Optional, Dict + +import tvm +from ..expr import DataflowBlock, Var +from .pattern import DFPattern +from . import _ffi as ffi + + +class PatternContext(tvm.runtime.Object): + """A context object for doing graph (topogical) pattern matching.""" + + def __init__(self, incremental=False): + """ + Initialize the PatternContext + + Parameters + ---------- + incremental : bool, optional + perform incremental matching based on the recent context, by default False + """ + self.__init_handle_by_constructor__(ffi.PatternContext, incremental) # type: ignore + + def __enter__(self): + """Enter the context""" + ffi.enter_context(self) # type: ignore + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context""" + ffi.exit_context(self) # type: ignore + + @staticmethod + def current() -> "PatternContext": + """ + Get the current context + + Returns + ------- + PatternContext + The current context + """ + return ffi.current_context() # type: ignore + + def match_dfb( + self, + dfb: DataflowBlock, + start_hint: Optional[Var] = None, + must_include_hint: bool = False, + ) -> Dict[DFPattern, Var]: + """ + Match a DataflowBlock via a graph of DFPattern and corresponding constraints + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to match + start_hint : Optional[Var], optional + Indicating the starting expression to match, by default None + must_include_hint : bool, optional + Whether the start_hint expression must be matched, by default False + + Returns + ------- + Dict[DFPattern, Var] + The mapping from DFPattern to matched expression + """ + return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py new file mode 100644 index 000000000000..79883b9161ec --- /dev/null +++ b/python/tvm/relax/dpl/pattern.py @@ -0,0 +1,1125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern types in Relax Dataflow Pattern Language""" +# pylint: disable=no-member +# pylint: disable=pointless-statement + +import typing +from typing import Dict, List, Optional, Tuple, Union + +import tvm +import tvm._ffi as tvm_ffi +from tvm.ir.container import Array +from tvm.ir.expr import PrimExpr +from tvm.relay.op import get + +from ...ir import make_node +from ...ir.base import Node +from ...runtime import Object +from ..expr import Expr, Var +from . import _ffi as ffi + + +def register_df_node(type_key=None): + """ + Register a Relax node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return tvm_ffi.register_object("relax.dpl." + type_key.__name__)(type_key) + return tvm_ffi.register_object(type_key) + + +class DFPattern(Node): + """Base class of all Patterns.""" + + def __call__(self, *args, varg_default_wildcard=False, add_constraint=True) -> "CallPattern": + """ + Syntax sugar for creating a CallPattern with argument patterns + + Returns + ------- + result: CallPattern + The resulting CallPattern + """ + return CallPattern(self, args, varg_default_wildcard, add_constraint) + + def __or__(self, other: "DFPattern") -> "OrPattern": + """ + Syntax sugar for creating an OrPattern + + Parameters + ---------- + other: DFPattern + Alternative pattern + + Returns + ------- + result: OrPattern + The resulting OrPattern + """ + return OrPattern(self, other) + + def __and__(self, other: "DFPattern") -> "AndPattern": + """ + Syntax sugar for creating an AndPattern + + Parameters + ---------- + other: DFPattern + Additional pattern to satisfy + + Returns + ------- + result: AndPattern + The resulting AndPattern + """ + return AndPattern(self, other) + + def __invert__(self) -> "NotPattern": + """ + Syntax sugar for creating a DFPattern to reject + + Returns + ------- + result: NotPattern + The resulting NotPattern + """ + return reject(self) + + def has_attr(self, attrs: Dict[str, Object]) -> "AttrPattern": + """ + Add an attribute constraint to this pattern + + Parameters + ---------- + attrs: Dict[str, Object] + + Returns + ------- + result: AttrPattern + The resulting AttrPattern + """ + attrs = make_node("DictAttrs", **attrs) + return AttrPattern(self, attrs) + + def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + ttype: tvm.ir.type.Type + The type to match + + Returns + ------- + result: TypePattern + The resulting TypePattern + """ + return TypePattern(self, ttype) + + def has_dtype(self, dtype: str) -> "DataTypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + dtype: str + The dtype to match + + Returns + ------- + result: DataTypePattern + The resulting DataTypePattern + """ + return has_dtype(dtype, self) + + def has_shape(self, shape: List[PrimExpr]) -> "ShapePattern": + """ + Add a shape constraint to this pattern + + Parameters + ---------- + shape: List[PrimExpr] + Expected shape list + + Returns + ------- + result: ShapePattern + The resulting ShapePattern + + Note 
+ ---- + has_shape assumes that the matched relax.Expr only has one + output tensor. Use is_tuple for those with multiple outputs. + """ + if not isinstance(shape, (list, tuple, tvm.ir.PrimExpr)): + raise ValueError("has_shape takes a list or tuple as input.") + return ShapePattern(pattern=self, shape=shape) + + def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: + """ + Match a relax.Expr syntactically + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match + var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]] + A mapping from relax.Var to relax.Expr for autojump. + + Returns + ------- + result: bool + Whether or not the expression matches the pattern + + Note + ---- + Unlike Relay whose function is an expression, functions in Relax consist + of blocks of bindings that are not syntactically connected. We use a + mapping (i.e., var2val) to mitigate the gap. For example, when matching + "relax.add(lv0, lv1)", given var2val, we match lv0's bound expression + when the recursive pattern matching goes to check lv0. The var2val mapping + can be computed through the tvm.relax.analysis.get_var2val function. + """ + return ffi.match_expr(self, expr, var2val) # type: ignore + + def extract_matched_expr( + self, expr, var2val: Optional[Dict[Var, Expr]] = None + ) -> Optional[Dict["DFPattern", Expr]]: + """ + Match a relax.Expr and return a map from matching patterns to matched expressions. + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match + var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]] + A mapping from relax.Var to relax.Expr for autojump. + + Returns + ------- + result: Optional[Dict[DFPattern, Expr]] + Map from matching patterns to matched expressions. + Return None if the pattern does not match expr. + + Note + ---- + Check the note of `match` for the meaning of var2val. 
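+
+        A minimal illustrative sketch (`expr` and `var2val` are assumed to
+        come from the caller; the pattern names are arbitrary):
+
+        .. code-block:: python
+
+            # match relax.add with two wildcard operands
+            lhs, rhs = wildcard(), wildcard()
+            pat = is_op("relax.add")(lhs, rhs)
+            # var2val can be obtained via tvm.relax.analysis.get_var2val(func)
+            matched = pat.extract_matched_expr(expr, var2val)
+            if matched is not None:
+                lhs_expr = matched[lhs]  # the Expr matched by the lhs wildcard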
+ """ + return ffi.extract_matched_expr(self, expr, var2val) + + def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _used_by(self, other, index) + + def __xor__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.used_by""" + return self.used_by(other, -1) + + def only_used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being **ONLY** used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _only_used_by(self, other, index) + + def __rshift__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "DFPattern": + """ + Duplicate the current pattern (new object under different address) + + Returns + ------- + DFPattern + A duplicated pattern + """ + return ffi.dup_pattern(self) # type: ignore + + def fork_to(self, *args) -> None: + """Fork the current pattern to multiple pattern branches""" + for v in args: + self ^ v + + +@register_df_node +class ExprPattern(DFPattern): + """A pattern which matches an expression. + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match. + """ + + def __init__(self, expr: Expr): + self.__init_handle_by_constructor__(ffi.ExprPattern, expr) # type: ignore + + +@register_df_node +class VarPattern(DFPattern): + """A pattern for Var. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.VarPattern, name_hint) # type: ignore + + +@register_df_node +class DataflowVarPattern(DFPattern): + """A pattern for DataflowVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.DataflowVarPattern, name_hint) # type: ignore + + +@register_df_node +class GlobalVarPattern(DFPattern): + """A pattern for GlobalVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any GlobalVarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.GlobalVarPattern, name_hint) # type: ignore + + +@register_df_node +class ExternFuncPattern(DFPattern): + """A external function pattern. + + Parameters + ---------- + global_symbol: str + The name of the function. Optional, if not provided, + the pattern will match any ExternFuncNode. 
+    """
+
+    def __init__(self, global_symbol: str = ""):
+        self.__init_handle_by_constructor__(ffi.ExternFuncPattern, global_symbol)  # type: ignore
+
+
+@register_df_node
+class ConstantPattern(DFPattern):
+    """A pattern matching a Relax Constant."""
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(ffi.ConstantPattern)  # type: ignore
+
+
+@register_df_node
+class CallPattern(DFPattern):
+    """A pattern matching a function call node.
+
+    Parameters
+    ----------
+    op: tvm.relax.dpl.DFPattern
+        The operation to be called.
+
+    args: List[tvm.relax.dpl.DFPattern]
+        The arguments to the call or None to match any arguments.
+
+    varg_default_wildcard: bool
+        If True, the listed args may be fewer than the arguments actually
+        provided to the matched call.
+
+    add_constraint: bool
+        If True, automatically add "used-by" constraints between the call and
+        its argument patterns.
+
+    Note
+    ----
+    By setting varg_default_wildcard to True, we can only focus on the argument
+    patterns we specified. For example, CallPattern(Op, [A, B]) can match
+    a call of Op(A, B) or Op(A, B, C, ...) that has more arguments. However,
+    the specified argument patterns must be matched (i.e., A and B).
+    """
+
+    def __init__(
+        self,
+        op: "DFPattern",
+        args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]],
+        varg_default_wildcard: bool = False,
+        add_constraint=True,
+    ):
+        self.__init_handle_by_constructor__(
+            ffi.CallPattern, op, args, varg_default_wildcard  # type: ignore
+        )
+
+        if add_constraint:
+            for i, arg in enumerate(args):
+                arg.used_by(self, i)
+
+
+@register_df_node
+class FunctionPattern(DFPattern):
+    """A pattern matching a function node in Relax.
+
+    Parameters
+    ----------
+    params: List[tvm.relax.dpl.DFPattern]
+        The parameters to the Function or None to match any parameters.
+
+    body: tvm.relax.dpl.DFPattern
+        The body of the Function.
+    """
+
+    def __init__(
+        self,
+        params: List["DFPattern"],
+        body: "DFPattern",
+    ):
+        self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)  # type: ignore
+
+
+@register_df_node
+class TuplePattern(DFPattern):
+    """A pattern matching a Relax Tuple.
+
+    Parameters
+    ----------
+    fields : Array[tvm.relax.dpl.DFPattern]
+        The fields in the tuple.
+    """
+
+    def __init__(self, fields: Array):
+        self.__init_handle_by_constructor__(ffi.TuplePattern, fields)  # type: ignore
+
+    def __getitem__(self, index: Optional[int]) -> "TupleGetItemPattern":
+        if index is not None:
+            # support negative index for being pythonic
+            if index < 0:
+                index += len(self)
+            if index >= len(self):
+                raise IndexError("TuplePattern index out of range")
+        else:
+            index = -1  # -1 means matching any index
+        return TupleGetItemPattern(self, index)
+
+    def __len__(self):
+        return len(self.fields)
+
+
+@register_df_node
+class UnorderedTuplePattern(DFPattern):
+    """A pattern matching a Relax Tuple with fields in any order.
+
+    Parameters
+    ----------
+    fields : Array[tvm.relax.dpl.DFPattern]
+        The fields in the tuple.
+    """
+
+    def __init__(self, fields: Array):
+        self.__init_handle_by_constructor__(ffi.UnorderedTuplePattern, fields)  # type: ignore
+
+    def __len__(self):
+        return len(self.fields)
+
+
+@register_df_node
+class TupleGetItemPattern(DFPattern):
+    """Get the index-th item from a TuplePattern.
+
+    Parameters
+    ----------
+    tuple_value: tvm.relax.dpl.DFPattern
+        The input tuple expression.
+
+    index: Optional[int]
+        The index to match; None (default) matches a TupleGetItem with any index.
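A minimal, illustrative sketch of the varg_default_wildcard flag described above (the operator and pattern choices are arbitrary, and the imports assume the module's public re-exports):

```python
import tvm
from tvm.relax.dpl import CallPattern, ExprPattern, VarPattern

# Constrain only the first argument of a relax.add call.
add_op = ExprPattern(tvm.ir.Op.get("relax.add"))
first_arg = VarPattern()
pat = CallPattern(add_op, [first_arg], varg_default_wildcard=True)
# The matched call may carry more arguments than listed here; only `first_arg`
# (a Var in the first position) has to match.
```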
+ """ + + def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None): + match_index = index if index is not None else -1 + self.__init_handle_by_constructor__( + ffi.TupleGetItemPattern, tuple_value, match_index # type: ignore + ) + + +@register_df_node +class OrPattern(DFPattern): + """Create a Pattern that can match one of two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One possible matching pattern. + right: tvm.relax.dpl.DFPattern + One possible matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.OrPattern, left, right) # type: ignore + + +@register_df_node +class AndPattern(DFPattern): + """Create a Pattern that must match two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One must-matching pattern. + right: tvm.relax.dpl.DFPattern + One must-matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.AndPattern, left, right) # type: ignore + + +@register_df_node +class NotPattern(DFPattern): + """Create a Pattern that matches the negation of a condition. + + Parameters + ---------- + to_reject: tvm.relax.dpl.DFPattern + The pattern to deny. + """ + + def __init__(self, to_reject: "DFPattern"): + self.__init_handle_by_constructor__(ffi.NotPattern, to_reject) # type: ignore + + +@register_df_node +class WildcardPattern(DFPattern): + """A pattern which matches anything.""" + + def __init__(self): + self.__init_handle_by_constructor__(ffi.WildcardPattern) # type: ignore + + +@register_df_node +class TypePattern(DFPattern): + """A pattern that matches another pattern with a certain type annotation. + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + ttype: tvm.ir.type.Type + The type to match. + """ + + def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): + self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) # type: ignore + + +@register_df_node +class DataTypePattern(DFPattern): + """A pattern that matches another pattern with certain data type + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + dtype: str + The dtype to match. + """ + + def __init__(self, pattern: "DFPattern", dtype: str): + self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype) # type: ignore + + +@register_df_node +class ShapePattern(DFPattern): + """A pattern that matches another pattern with a certain tensor shape + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + shape: List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) # type: ignore + + +@register_df_node +class PrimArrPattern(DFPattern): + """ + A pattern to match an array of PrimExpr + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The shape to match. 
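To show how these building blocks compose, a hedged sketch that accepts a tensor expression of either float32 or float16 dtype (imports assume the module's public re-exports):

```python
from tvm.relax.dpl import DataTypePattern, OrPattern, WildcardPattern

fp32 = DataTypePattern(WildcardPattern(), "float32")
fp16 = DataTypePattern(WildcardPattern(), "float16")
either_dtype = OrPattern(fp32, fp16)   # matches if either branch matches
```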
+    """
+
+    def __init__(self, shape: List[tvm.ir.PrimExpr]):
+        self.__init_handle_by_constructor__(ffi.PrimArrPattern, shape)  # type: ignore
+
+    def __getitem__(self, index: int):
+        if index >= len(self):
+            raise IndexError("PrimArrPattern index out of range")
+        return self.fields[index]
+
+    def __len__(self):
+        return len(self.fields)
+
+
+@register_df_node
+class AttrPattern(DFPattern):
+    """Match an expression that carries certain attributes.
+    Currently only supports Op Attributes, not call Attributes.
+
+    Parameters
+    ----------
+    pattern: tvm.relax.dpl.DFPattern
+        The input pattern.
+
+    attrs: tvm.ir.attrs.Attrs
+        The attributes to match.
+    """
+
+    def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs):
+        self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs)  # type: ignore
+
+
+def is_var(name: str = "") -> VarPattern:
+    """
+    Syntactic sugar for creating an optionally named VarPattern.
+
+    Parameters
+    ----------
+    name: str
+        The name of the input pattern to match.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.VarPattern
+        The resulting pattern.
+    """
+    return VarPattern(name)
+
+
+def is_gv(name: str = "") -> GlobalVarPattern:
+    """Syntax sugar for creating an optionally (if name is empty) named GlobalVarPattern."""
+    return GlobalVarPattern(name)
+
+
+def is_dfv(name: str = "") -> DataflowVarPattern:
+    """Syntax sugar for creating an optionally (if name is empty) named DataflowVarPattern."""
+    return DataflowVarPattern(name)
+
+
+def is_const() -> ConstantPattern:
+    """
+    Syntactic sugar for creating a ConstantPattern.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.ConstantPattern
+        The resulting pattern.
+    """
+    return ConstantPattern()
+
+
+def is_expr(expr: Expr) -> ExprPattern:
+    """
+    Syntactic sugar for creating an ExprPattern.
+
+    Parameters
+    ----------
+    expr: Expr
+        The Relax expression to match.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.ExprPattern
+        The resulting pattern.
+    """
+    return ExprPattern(expr)
+
+
+def is_op(op_name: str) -> ExprPattern:
+    """
+    Syntactic sugar for creating an operator ExprPattern.
+
+    Parameters
+    ----------
+    op_name: str
+        The name of the tvm.ir.op.Op object
+
+    Returns
+    -------
+    result: tvm.relax.dpl.ExprPattern
+        The resulting ExprPattern
+    """
+    op = get(op_name)
+    return ExprPattern(op)
+
+
+def is_tuple(
+    fields: Union[Array, List, Tuple], unordered=False
+) -> Union[TuplePattern, UnorderedTuplePattern]:
+    """
+    Syntactic sugar for creating a TuplePattern or an UnorderedTuplePattern.
+
+    Parameters
+    ----------
+    fields : Array[tvm.relax.dpl.DFPattern]
+        The fields in the tuple.
+
+    unordered : bool
+        If True, the fields may match in any order (an UnorderedTuplePattern is returned).
+
+    Returns
+    -------
+    result: tvm.relax.dpl.DFPattern
+        The resulting pattern.
+    """
+    if not isinstance(fields, (list, tuple, Array)):
+        raise ValueError("fields must be a list, tuple, or Array")
+    if unordered:
+        return UnorderedTuplePattern(fields)
+    return TuplePattern(fields)
+
+
+def is_tuple_get_item(tuple_value: DFPattern, index: Optional[int] = None) -> TupleGetItemPattern:
+    """
+    Syntactic sugar for creating a TupleGetItemPattern.
+
+    Parameters
+    ----------
+    tuple_value: tvm.relax.dpl.DFPattern
+        The input tuple expression.
+
+    index: Optional[int]
+        The index to match; None (default) matches a TupleGetItem with any index.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.TupleGetItemPattern
+        The resulting pattern.
+    """
+    return TupleGetItemPattern(tuple_value, index)
+
+
+def wildcard() -> WildcardPattern:
+    """
+    Syntactic sugar for creating a WildcardPattern.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.WildcardPattern
+        The resulting pattern.
+    """
+    return WildcardPattern()
+
+
+def has_dtype(dtype: str, pattern: DFPattern = None) -> DataTypePattern:
+    """
+    Syntactic sugar for creating a DataTypePattern
+
+    Parameters
+    ----------
+    dtype: str
+        The dtype to match
+
+    pattern: tvm.relax.dpl.DFPattern
+        The pattern that needs type annotation
+
+    Returns
+    -------
+    result: tvm.relax.dpl.DataTypePattern
+        The resulting DataTypePattern
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return DataTypePattern(pattern, dtype)
+
+
+def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
+    """
+    Directly matches a shape, which is an array of PrimExpr.
+
+    Parameters
+    ----------
+    shape : List[tvm.ir.PrimExpr]
+        The expected shape
+
+    Returns
+    -------
+    PrimArrPattern
+        The resulting PrimArrPattern pattern
+
+    Raises
+    ------
+    ValueError
+        If the argument shape is not a list/tuple/tvm.ir.Array
+
+    Note
+    ----
+    The difference between p.has_shape(s) and is_shape(s) is that has_shape
+    constrains the shape of the tensor matched by pattern p, while is_shape
+    directly matches the shape itself (an array of PrimExpr).
+    """
+    if not isinstance(shape, (list, tuple, tvm.ir.Array)):
+        raise ValueError("is_shape takes a list or tuple as input.")
+    return PrimArrPattern(shape)
+
+
+# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
+def _is_call_tir(
+    func_pattern: DFPattern,
+    args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+    if args is None:
+        args = wildcard()
+    elif isinstance(args, (list, tuple)):
+        args = TuplePattern(args)
+
+    return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
+
+
+# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
+def is_call_tir(
+    func_name: str,
+    args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+    """
+    Syntax sugar for creating a CallPattern for call_tir that calls a function through a GlobalVar.
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the TIR function (referenced through a GlobalVar) to call.
+    args : Union[List[DFPattern], Tuple[DFPattern]], optional
+        Arguments of the expected call_tir; by default None, meaning arbitrary (number of) arguments
+
+    Returns
+    -------
+    CallPattern
+        The resulting CallPattern
+    """
+    func_pattern = GlobalVarPattern(func_name)
+    return _is_call_tir(func_pattern, args)
+
+
+def _is_call_dps_packed(
+    func_pattern: DFPattern,
+    args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+    if args is None:
+        args = wildcard()
+    elif isinstance(args, (list, tuple)):
+        args = TuplePattern(args)
+
+    return is_op("relax.call_dps_packed")(func_pattern, args, add_constraint=False)
+
+
+def is_call_dps_packed(
+    func_name: str,
+    args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+    """Syntax sugar for creating a CallPattern for call_dps_packed
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the packed (destination-passing-style) function to call.
+ args : Union[List[DFPattern], Tuple[DFPattern]], optional + Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments + + Returns + ------- + CallPattern + The resulting CallPattern + """ + func_pattern = ExternFuncPattern(func_name) + return _is_call_dps_packed(func_pattern, args) + + +def is_call_packed( + func_name: str, args: Union[List[DFPattern], Tuple[DFPattern]] = None +) -> CallPattern: + """ + Syntax sugar for creating a CallPattern for call_packed + + Parameters + ---------- + func_name : str + Name of the external function to call + args : Union[List[DFPattern], Tuple[DFPattern]], optional + Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments + + Returns + ------- + CallPattern + The resulting CallPattern + """ + if args is None: + return ExternFuncPattern(func_name)(varg_default_wildcard=True, add_constraint=False) + return ExternFuncPattern(func_name)(*args) + + +def reject(pattern: DFPattern) -> NotPattern: + """ + Syntax sugar for creating a DFPattern to reject + + Parameters + ---------- + pattern : DFPattern + The pattern to deny + + Returns + ------- + result: NotPattern + The resulting NotPattern + """ + return NotPattern(pattern) + + +def has_attr(attrs, pattern=None) -> AttrPattern: + """ + Syntatic sugar for creating an AttrPattern + + Parameters + ---------- + attrs: Dict[str, Object] + The attributes to match + + pattern: Optional[tvm.relax.dpl.DFPattern] + The input pattern. + + Returns + ------- + result: tvm.relax.dpl.DFPattern + The resulting AttrPattern + """ + if pattern is None: + pattern = wildcard() + return pattern.has_attr(attrs) + + +@register_df_node +class PatternSeq(Node): + """A sequence of patterns with consecutive constraints""" + + def __init__(self, patterns: List[DFPattern], only_use=False): + """ + Initializer to PatternSeq + + Parameters + ---------- + patterns : List[DFPattern] + A chain of patterns + only_use : bool, optional + Whether the patterns follows only-used-by relations consecutively, by default False + """ + self.__init_handle_by_constructor__(ffi.PatternSeq, patterns, only_use) # type: ignore + + def used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq": + """ + Assuming the right-most pattern must be used by the `other` pattern as a producer + + Parameters + ---------- + other : Union[DFPattern, PatternSeq] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + PatternSeq + A chained pattern sequence + + Note + ---- + If other is PatternSeq, it means the right-most pattern must be used by the left-most + pattern of the other sequence. + """ + return _used_by(self, other, index) + + def only_used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq": + + """ + Assuming the right-most pattern must be **ONLY** used by the `other` pattern as a producer + + Parameters + ---------- + other : Union[DFPattern, PatternSeq] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + PatternSeq + A chained pattern sequence + + Note + ---- + If other is PatternSeq, it means the right-most pattern must be **ONLY** used by the + left-most pattern of the other sequence. 
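A short sketch of the chaining sugar above; the resulting used-by / only-used-by constraints are what a PatternContext consults when matching dataflow blocks (op names are illustrative, imports assume the module's public re-exports):

```python
from tvm.relax.dpl import is_op, wildcard

conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
relu = is_op("relax.nn.relu")(wildcard())

seq = conv >> relu   # conv's result is ONLY consumed by relu (only_used_by)
# conv ^ relu would allow additional consumers (used_by); call conv.dup() if the
# same structural pattern is needed in a second, independent constraint.
```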
+ """ + return _only_used_by(self, other, index) + + def __getitem__(self, index: int) -> DFPattern: + """ + Access the pattern at the given index + + Parameters + ---------- + index : int + Index of the accessed pattern + + Returns + ------- + DFPattern + The accessed pattern + """ + return self.patterns[index] + + def __xor__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.used_by""" + return self.used_by(other, -1) + + def __rshift__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "PatternSeq": + """ + Duplicate the pattern sequence (new object under different address) + + Returns + ------- + PatternSeq + A duplicated chain + """ + return ffi.dup_seq(self) # type: ignore + + +### Private functions + + +def _used_by( + lhs: Union[DFPattern, PatternSeq], + rhs: Union[DFPattern, PatternSeq], + index=-1, +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.used_by(lhs, rhs, index) # type: ignore + + +def _only_used_by( + lhs: Union[DFPattern, PatternSeq], rhs: Union[DFPattern, PatternSeq], index=-1 +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.only_used_by(lhs, rhs, index) # type: ignore + + +def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None): + """ + A simple utility to create patterns for an operation fused with bias addition and activation. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused operation + """ + lhs = wildcard() + rhs = wildcard() + out = is_op(op_name)(lhs, rhs) + + if with_bias: + bias = wildcard() + out = is_op("relax.add")(out, bias) + + if activation: + return is_op(activation)(out) + + return out diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py new file mode 100644 index 000000000000..1b62a429030e --- /dev/null +++ b/python/tvm/relax/dpl/rewrite.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""APIs for pattern-based rewriting.""" +from typing import Dict, Callable +from .pattern import DFPattern +from .context import PatternContext + +from ..expr import Expr, Function, Var +from . 
import _ffi as ffi + + +def rewrite_call( + pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function +) -> Function: + """ + Rewrite a function with the given pattern and the rewriter function. + + Parameters + ---------- + pattern: DFPattern + The pattern to match. + + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + The function to be called on a successful matching for rewriting. Given the matched + call node and the map of patterns and matched expressions, it should return a new call node + to replace the original one or the original matched call node as is. + + For example, to replace x + x with 2 * x, we can write the rewriter as follows: + ``` + x = wildcard() + pattern = is_op("relax.add")(x, x) + + def rewriter(orig, matchings): + return R.multiply(matchings[x], R.const(2, "float32")) + ``` + + func: Function + The function to rewrite. + + Returns + ------- + rewritten_func: Function + The rewritten or the input function, depending on the pattern matching result. + """ + return ffi.rewrite_call(pattern, rewriter, func) + + +def rewrite_bindings( + ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]], func: Function +) -> Function: + """ + Rewrite a function with the given pattern and the rewriter function. + + Parameters + ---------- + ctx: PatternContext + The pattern constraint context under which rewriting takes place. + + rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]] + The function to be called on a successful matching for rewriting. Given the map of patterns + and corresponding variables (bound variables or parameters), it should return a map that + specifies new values for matched bound variables. + + For example, to rewrite three matmuls for QKV projection in transformer models into one + matmul followed by slicing, one can use the follwoing rewriter: + ``` + inp_pat = wildcard() + Q_weight_pat, K_weight_pat, V_weight_pat = wildcard(), wildcard(), wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + def rewriter(matchings): + inp = matchings[inp_pat] + Q_weight = matchings[Q_weight_pat] + K_weight = matchings[K_weight_pat] + V_weight = matchings[V_weight_pat] + width = Q_weight.struct_info.shape[1] + + concat = R.concat([Q_weight, K_weight, V_weight], axis=1) + matmul = R.matmul(inp, concat) + Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width]) + K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2]) + V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3]) + + # matchings[matmul1] gives the bound variable in the binding whose RHS matches with + # the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0). + # We want to replace the RHS of this binding with Q. + return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} + ``` + + func: Function + The function to rewrite. + + Returns + ------- + rewritten_func: Function + The rewritten or the input function, depending on the pattern matching result. + """ + return ffi.rewrite_bindings(ctx, rewriter, func) diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py new file mode 100644 index 000000000000..140c497eb967 --- /dev/null +++ b/python/tvm/relax/exec_builder.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""A builder to build Relax VM executable.""" +from enum import IntEnum +from typing import Optional, Union, List +import tvm +from tvm.runtime import Object +from tvm.runtime.container import ShapeTuple +from .vm_build import Executable +from . import _ffi_api + + +class SpecialReg(IntEnum): + """Magic numbers that represent special registers in vm.""" + + VOID_ARG = (1 << 54) + 0 + VM_STATE = (1 << 54) + 1 + + +class VMFuncKind(IntEnum): + """VM function kind code.""" + + PACKED_FUNC = 0 + VM_FUNC = 1 + + +class VMFuncScope(object): + """An object corresponds to each VM function, working as a context manager.""" + + stack: List["VMFuncScope"] = [] + + def __init__(self, exit_callback): + self.exit_callback = exit_callback + + def __enter__(self): + VMFuncScope.stack.append(self) + return self + + def __exit__(self, ptype, value, trace): + VMFuncScope.stack.pop() + self.exit_callback() + + +@tvm._ffi.register_object("relax.ExecBuilder") +class ExecBuilder(Object): + """A builder to emit instructions and build executable for the virtual machine.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) # type: ignore + + def r(self, idx: int) -> int: + """set instruction's argument as a register.""" + return _ffi_api.ExecBuilderR(self, idx) # type: ignore + + def imm(self, value: int) -> int: + """set instruction's argument as an immediate.""" + return _ffi_api.ExecBuilderImm(self, value) # type: ignore + + def c(self, idx: int) -> int: + """set instruction's argument as a constant.""" + return _ffi_api.ExecBuilderC(self, idx) # type: ignore + + def f(self, name: str) -> int: + """set instruction's argument as a function.""" + return _ffi_api.ExecBuilderF(self, name) # type: ignore + + def void_arg(self) -> int: + return self.r(SpecialReg.VOID_ARG) + + def vm_state(self) -> int: + return self.r(SpecialReg.VM_STATE) + + def declare_function(self, func_name: str, kind: VMFuncKind = VMFuncKind.PACKED_FUNC) -> None: + """Declare a function""" + _ffi_api.ExecBuilderDecalreFunction(self, func_name, kind) # type: ignore + + def function( + self, func_name: str, num_inputs: Optional[int] = 0, param_names: List[str] = None + ) -> VMFuncScope: + """annotate a VM function.""" + _ffi_api.ExecBuilderEmitFunction(self, func_name, num_inputs, param_names) # type: ignore + return VMFuncScope(lambda: _ffi_api.ExecBuilderEndFunction(self, func_name)) # type: ignore + + def _check_scope(self) -> None: + if len(VMFuncScope.stack) == 0: + raise ValueError("emit should happen in a function scope") + + def convert_constant(self, const: object) -> int: + return _ffi_api.ExecBuilderConvertConstant(self, const) # type: ignore + + def emit_call( + self, + name: str, + args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = None, + dst: int = None, + ) -> None: 
+ """emit a call instruction which calls a packed function.""" + self._check_scope() + if dst is None: + dst = SpecialReg.VOID_ARG + args_ = [] + if args is not None: + for arg in args: + if isinstance(arg, tuple): + shape_tuple = ShapeTuple(arg) + new_arg = self.convert_constant(shape_tuple) + args_.append(new_arg) + elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)): + new_arg = self.convert_constant(arg) + args_.append(new_arg) + else: + args_.append(arg) + _ffi_api.ExecBuilderEmitCall(self, name, args_, dst) # type: ignore + + def emit_ret(self, result: int) -> None: + """emit a return instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitRet(self, result) # type: ignore + + def emit_goto(self, pc_offset): + """emit a goto instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitGoto(self, pc_offset) # type: ignore + + def emit_if(self, cond, false_offset): + """emit an if instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitIf(self, cond, false_offset) # type: ignore + + def get(self) -> Executable: + """return the executable""" + return Executable(_ffi_api.ExecBuilderGet(self)) # type: ignore diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py new file mode 100644 index 000000000000..fdf98c179b7c --- /dev/null +++ b/python/tvm/relax/expr.py @@ -0,0 +1,706 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import, super-init-not-called +# pylint: disable=redefined-builtin +"""The expression nodes of Relax.""" +import typing +from numbers import Number +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as _np # type: ignore + +import tvm +import tvm._ffi +import tvm.ir +import tvm.relax +from tvm import DataType +from tvm._ffi import base as _base +from tvm.runtime import Object +from tvm.runtime import ndarray as _nd + +from ..ir import BaseFunc, Node, SourceName, Span +from ..runtime import Scriptable, String +from ..tir import PrimExpr +from . import _ffi_api + +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +Expr = Union[tvm.ir.RelayExpr] +Type = Union[tvm.ir.Type] +GlobalVar = Union[tvm.ir.GlobalVar] + + +@tvm._ffi.register_object("relax.Id") +class Id(Object): + """Unique identifier(name) used in Var. + Guaranteed to be stable across all passes. + """ + + def __init__(self): + raise RuntimeError("Cannot directly construct Id") + + +# NOTE: place base struct info in expr to avoid cyclic dep +# from expr to struct info. +class StructInfo(Node, Scriptable): + """The base class of all StructInfo. 
+ + StructInfo contains both the static type + and runtime structural information. + """ + + def __eq__(self, other): + """Compare two struct info for structural equivalence.""" + return tvm.ir.structural_equal(self, other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Overload with structural equality.""" + return super().__eq__(other) + + def is_base_of(self, derived: "StructInfo") -> bool: + """Check if self is base of another derived struct info. + + Parameters + ---------- + derived : StructInfo + The derived struct info to be checked. + + Returns + ------- + result : bool + The check result. + """ + return _ffi_api.StructInfoIsBaseOf(self, derived) # type: ignore + + +# will be registered afterwards in python/tvm/relax/op/init.py +_op_ffi_api = None + + +def _binary_op_helper(lhs: "ExprWithOp", rhs: "ExprWithOp", op: Callable) -> "ExprWithOp": + if not isinstance(lhs, Expr): # type: ignore + raise ValueError("lhs must be Expr") + if isinstance(rhs, Expr): # type: ignore + return op(lhs, rhs) + elif isinstance(rhs, Number): + raise TypeError(f"Please convert {rhs} with `const` first") + else: + raise TypeError(f"type {type(rhs)} not supported") + + +def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp": + if isinstance(rhs, Number): + raise TypeError(f"Please convert {rhs} with `const` first") + raise TypeError(f"type {type(rhs)} not supported") + + +class ExprWithOp(Expr, Scriptable): + """Basetype of all relax expressions that defines op overloading.""" + + def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": + """Cast the content type of the current data to dtype. + + Parameters + ---------- + dtype : str + The target data type. + + Note + ---- + This function only works for TensorType Exprs. + + Returns + ------- + result : ExprWithOp + The result expression. 
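A small sketch of how these operator overloads behave in practice, assuming `tvm.relax` has been imported so the op FFI is registered:

```python
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo([4], "float32"))
y = x + x                              # builds a call to the relax.add op
z = x * relax.const(2.0, "float32")    # plain Python scalars must be wrapped with relax.const
w = x.astype("float16")                # builds a call to the astype op
```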
+ """ + return _op_ffi_api.astype(self, dtype) # type: ignore + + def __neg__(self) -> "ExprWithOp": + return _op_ffi_api.negative(self) # type: ignore + + def __lt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore + + def __gt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater) # type: ignore + + def __ge__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater_equal) # type: ignore + + def __le__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less_equal) # type: ignore + + # NOTE: Cannot override __eq__ and __ne__, which will influence object equal + + def __add__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.add) # type: ignore + + def __radd__(self, other: Expr) -> "ExprWithOp": + return self.__add__(other) + + def __sub__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.subtract) # type: ignore + + def __rsub__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mul__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.multiply) # type: ignore + + def __rmul__(self, other: Expr) -> "ExprWithOp": + return self.__mul__(other) + + def __truediv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.divide) # type: ignore + + def __rtruediv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __floordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.floor_divide) # type: ignore + + def __rfloordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mod__(self, other: Expr) -> "ExprWithOp": + # TODO(siyuan): Support it after mod operator is supported in relax + raise ValueError("relax.mod is not supported yet.") + + def __rmod__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __pow__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.power) # type: ignore + + def __rpow__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> "ExprWithOp": + """Call the variable (if it represents a function). + + Parameters + ---------- + args: List[Expr] + The arguments to the call. + + attr: Optional[Dict[str, object]] + The additional attributes to the call. + + Returns + ------- + call: ExprWithOp + A call taking the variable as a function. + """ + return Call(self, args, attrs=attrs) + + def __getitem__(self, index: int) -> "ExprWithOp": + """Get the i-th element of the tuple or Expr with TupleType. + + Parameters + ---------- + index: int + The index of the element to be retrieved. + + Note + ---- + This function will be overridden by Tuple and ShapeExpr + + Returns + ------- + result: ExprWithOp + The result expression. + """ + return TupleGetItem(self, index) + + +@tvm._ffi.register_object("relax.expr.Call") +class Call(ExprWithOp): + """Function call node in Relax. + + Call node corresponds the operator application node + in computational graph terminology. + + Parameters + ---------- + op: tvm.ir.Op or any tvm.relax.Expr with function type. + The operation to be called. + + args: Union[List[Expr], typing.Tuple[Expr, ...]] + The arguments to the call. 
+ + attrs: Optional[tvm.ir.Attrs] + Attributes to the call, can be None + + sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] + The structure info arguments of a CallNode. + sinfo_args is designed to be non-empty only for intrinsic op (e.g., + call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main + usage of structure info inference. + + span: Optional[Span] + Span that points to original source code + """ + + def __init__( + self, + op: Union[Expr, tvm.ir.Op], + args: Union[List[Expr], typing.Tuple[Expr, ...]], + attrs: Optional[tvm.ir.Attrs] = None, + sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] = None, + span: Optional[Span] = None, + ): + if not sinfo_args: + sinfo_args = [] + self.__init_handle_by_constructor__( + _ffi_api.Call, op, args, attrs, sinfo_args, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.If") +class If(ExprWithOp): + """A conditional expression in Relax. + + Parameters + ---------- + cond: Expr + The condition. + + true_branch: Expr + The expression evaluated when condition is true. + + false_branch: Expr + The expression evaluated when condition is false. + """ + + def __init__(self, cond: Expr, true_branch: Expr, false_branch: Expr, span: Span = None): + self.__init_handle_by_constructor__( + _ffi_api.If, cond, true_branch, false_branch, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.Tuple") +class Tuple(ExprWithOp): + """Tuple expression that groups several fields together. + + Parameters + ---------- + fields : Union[List[Expr], typing.Tuple[Expr, ...]] + The fields in the tuple. + + span: Optional[Span] + Span that points to original source code + """ + + def __init__(self, fields: Union[List[Expr], typing.Tuple[Expr, ...]], span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span) # type: ignore + + def __getitem__(self, index: int) -> Expr: + if index >= len(self) or index < -len(self): + raise IndexError("Tuple index out of range") + return self.fields[index] + + def __len__(self) -> int: + return len(self.fields) + + +@tvm._ffi.register_object("relax.expr.TupleGetItem") +class TupleGetItem(ExprWithOp): + """Get index-th item from a tuple. + + Parameters + ---------- + tuple_value: Expr + The input tuple expression. + + index: int + The index. 
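Purely for illustration (the BlockBuilder and TVMScript are the usual front ends for constructing Relax IR), the expression nodes in this module can also be instantiated directly:

```python
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo([2, 3], "float32"))
y = relax.Var("y", relax.TensorStructInfo([2, 3], "float32"))

add = relax.Call(tvm.ir.Op.get("relax.add"), [x, y])
tup = relax.Tuple([x, y])
field = tup[0]                        # Tuple.__getitem__ returns the field itself (x)
item = relax.TupleGetItem(tup, 0)     # an explicit TupleGetItem node
```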
+ """ + + def __init__(self, tuple_value: Expr, index: int): + self.__init_handle_by_constructor__( + _ffi_api.TupleGetItem, tuple_value, index # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.ShapeExpr") +class ShapeExpr(ExprWithOp): + """A shape expression which allows users to construct a shape containing PrimExpr.""" + + values: List[PrimExpr] + + def __init__( + self, + values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array], + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) # type: ignore + + def __getitem__(self, index): + if index >= len(self) or index < -len(self): + raise IndexError("ShapeExpr index out of range") + return self.values[index] + + def __len__(self): + return len(self.values) + + +def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: + if isinstance(shape, (list, tuple)): + return ShapeExpr(shape) + raise ValueError("Wrong type") + + +@tvm._ffi.register_object("relax.expr.Constant") +class Constant(ExprWithOp): + def __init__(self, data: tvm.nd.NDArray, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Var") +class Var(ExprWithOp): + """The variable class for all Relax bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." + ) + self.__init_handle_by_constructor__( + _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, # type: ignore + name_hint, + struct_info, + span, + ) + + @property + def name_hint(self): + """Get name hint of the current var.""" + name = str(self.vid.name_hint) + return name + + +@tvm._ffi.register_object("relax.expr.DataflowVar") +class DataflowVar(Var): + """A sub-type of the variable node used to mark dataflow variables from + normal visible "function local" bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." 
+ ) + + self.__init_handle_by_constructor__( + _ffi_api.DataflowVar # type: ignore + if isinstance(name_hint, str) + else _ffi_api.DataflowVarFromId, # type: ignore + name_hint, + struct_info, + span, + ) + + +@tvm._ffi.register_object("relax.expr.PrimValue") +class PrimValue(Expr, Scriptable): + """The prim expr representing the value.""" + + value: PrimExpr + + def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None: + if isinstance(value, int): + value = tvm.tir.IntImm("int64", value) + self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.StringImm") +class StringImm(Expr, Scriptable): + """Represent a string literal constant.""" + + value: str + + def __init__(self, value: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.DataTypeImm") +class DataTypeImm(Expr, Scriptable): + """Represent a data type constant.""" + + value: DataType + + def __init__(self, value: Union[DataType, str], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DataTypeImm, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Binding") +class Binding(Node, Scriptable): + """The base class of a binding in Relax.""" + + +@tvm._ffi.register_object("relax.expr.MatchCast") +class MatchCast(Binding): + """Runtime-match the value to the struct info. + + This operation does runtime check, populates the un-defined symbolic shape vars + and vars in struct_info in the first occurrence, and insert equality assertions in + other cases. + + Parameters + ---------- + var: Var + The return variable that the match cast bind to. + + value: Expr + The input value expression. + + struct_info: tvm.relax.StructInfo + The struct info to match cast to. 
+ """ + + var: Var + struct_info: "tvm.relax.StructInfo" + value: Expr + + def __init__( + self, var: Var, value: Expr, struct_info: "tvm.relax.StructInfo", span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MatchCast, var, value, struct_info, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.VarBinding") +class VarBinding(Binding): + """Variable binding, bind he variable of the lhs with the rhs.""" + + var: Var + value: Expr + + def __init__(self, var: Var, value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.BindingBlock") +class BindingBlock(Node, Scriptable): + """base class of binding block, bindings inside can be impure + (with side effect or control flow)""" + + bindings: List[Binding] + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.DataflowBlock") +class DataflowBlock(BindingBlock): + """dataflow block, bindings inside are pure (no side effect and no control flow)""" + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.SeqExpr") +class SeqExpr(ExprWithOp): + """A sequence of binding blocks followed by an expression.""" + + blocks: List[BindingBlock] + body: Expr + + def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Function") +class Function(BaseFunc, Scriptable): + """A Relax function.""" + + params: List[Var] + body: Expr + ret_struct_info: StructInfo + attrs: Optional[tvm.ir.DictAttrs] + + def __init__( + self, + params: List[Var], + body: Expr, + ret_struct_info: Optional[StructInfo] = None, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Function, params, body, ret_struct_info, attrs, span # type: ignore + ) + + @staticmethod + def create_empty( + params: List[Var], + ret_struct_info: StructInfo, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ): + """Construct a relax.Function but without body""" + return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span) # type: ignore + + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relax.Expr] + Arguments. + """ + return Call(self, args, None, None) + + +@tvm._ffi.register_object("relax.expr.ExternFunc") +class ExternFunc(BaseFunc): + """extern function, which represents a PackedFunc.""" + + global_symbol: String + + def __init__(self, global_symbol: String, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExternFunc, global_symbol, span # type: ignore + ) + + +def extern(name: str, span: Span = None): + """Create extern function.""" + return ExternFunc(name, span) + + +def const( + value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], dtype: Optional[str] = None +) -> Constant: + """Create a constant value. + + Parameters + ---------- + value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + The constant value. 
+ + dtype: Optional[str] + The data type of the resulting constant. + + Note + ---- + When dtype is None, we use the following rule: + + - int maps to "int32" + - float maps to "float32" + - bool maps to "bool" + - other using the same default rule as numpy. + """ + if isinstance(value, (_base.numeric_types, (bool, list))): + value = _np.array(value, dtype=dtype) + + if not dtype: + # when dtype is None: int maps to "int32", float maps to "float32" + dtype = { # type: ignore + _np.dtype("int64"): _np.int32, # type: ignore + _np.dtype("float64"): _np.float32, # type: ignore + }.get( + value.dtype, None # type: ignore + ) + + if isinstance(value, (_np.ndarray, _np.generic)): + if dtype is not None: + value = value.astype(dtype) + value = _nd.array(value) + + if not isinstance(value, _nd.NDArray): + raise ValueError("value has to be scalar or NDArray") + + return Constant(value) + + +def te_tensor( + value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" +): + """Create a TE tensor from relax expression, with TIR variables in the + tensor shape substituted by the given mapping + + Parameters + ---------- + value : Expr + The relax expression, which is required to have TensorStructInfo. + + tir_var_map : Dict[tvm.tir.Var, tvm.tir.PrimExpr] + The mapping to substitute the TIR variables appeared in the + shape of the input Expr. + + name : str + The name of the created tensor. + """ + return _ffi_api.TETensor(value, tir_var_map, name) # type: ignore + + +def get_shape_of(expr: Expr) -> Expr: + """Get shape of expr. + + Parameters + ---------- + expr: Expr + The input expr. + + Returns + ------- + shape: Expr + The shape expression + + Note + ---- + This function requires expr to be normalized. + The function will report an error if expr's StructInfo is not TensorStructInfo. + It will try to return symbolic function when possible. If the tensor do not + have a compile-time symbolic shape, the function will then choose to return + `Call(relax.op.shape_of, [expr])`. + """ + return _ffi_api.GetShapeOf(expr) # type: ignore + + +def _update_struct_info(expr: Expr, struct_info: Optional[StructInfo]) -> None: + _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py new file mode 100644 index 000000000000..0252720f6ee8 --- /dev/null +++ b/python/tvm/relax/expr_functor.py @@ -0,0 +1,1530 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ +"""The expression functor of Relax.""" +from typing import Callable, Optional + +import tvm +from tvm.ir import Op +from tvm.meta_schedule.utils import derived_object +from tvm.runtime import Object + +from ..ir.module import IRModule +from . import _ffi_api +from .block_builder import BlockBuilder +from .expr import ( + Binding, + BindingBlock, + Call, + Constant, + Id, + DataflowBlock, + DataflowVar, + DataTypeImm, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + PrimValue, + SeqExpr, + ShapeExpr, + Span, + StringImm, + Tuple, + TupleGetItem, + Var, + VarBinding, +) +from .struct_info import StructInfo + +visitor = derived_object +""" +A decorator to wrap user-customized PyExprVisitor as TVM object _PyExprVisitor. + +Parameters +---------- +visitor_cls : PyExprVisitor + The user-customized PyExprVisitor. + +Returns +------- +cls : _PyExprVisitor + The decorated TVM object _PyExprVisitor(ExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.visitor + class MyExprVisitor(PyExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... + # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyExprVisitor() + # apply myvisitor to Expr/Binding/BindingBlock/VarDef + myvisitor.visit_expr(expr) + myvisitor.visit_binding(binding) + myvisitor.visit_binding_block(bindingblock) + myvisitor.visit_var_def(var) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyExprMutator as TVM object _PyExprMutator. +Note: Cannot override visit function and post-order rewrite at the same time. + +Parameters +---------- +mutator_cls : PyExprMutator + The user-customized PyExprMutator. + +Returns +------- +cls : _PyExprMutator + The decorated TVM object _PyExprMutator(ExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.mutator + class MyExprMutator(PyExprMutator): + # customize rewrite function + def visit_tuple_(self, op: Tuple) -> Expr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Tuple with + # user-customized visit_tuple_ + mymutator = MyExprMutator() + # apply mymutator to Expr/Binding/BindingBlock/VarDef + mymutator.visit_expr(expr) + mymutator.visit_binding(binding) + mymutator.visit_binding_block(bindingblock) + mymutator.visit_var_def(var) +""" + + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + Defines the default dispatch over expressions, and + implements memoization. 
+ """ + + def visit_expr(self, expr: Expr) -> Expr: + """Apply the visitor to an expression.""" + if isinstance(expr, Constant): # type: ignore + ret = self.visit_constant_(expr) + elif isinstance(expr, Tuple): + ret = self.visit_tuple_(expr) + elif isinstance(expr, DataflowVar): + ret = self.visit_dataflow_var_(expr) + elif isinstance(expr, Var): + ret = self.visit_var_(expr) + elif isinstance(expr, ShapeExpr): + ret = self.visit_shape_expr_(expr) + elif isinstance(expr, ExternFunc): + ret = self.visit_extern_func_(expr) + elif isinstance(expr, GlobalVar): # type: ignore + ret = self.visit_global_var_(expr) + elif isinstance(expr, Function): + ret = self.visit_function_(expr) + elif isinstance(expr, Call): # type: ignore + ret = self.visit_call_(expr) + elif isinstance(expr, SeqExpr): + ret = self.visit_seq_expr_(expr) + elif isinstance(expr, If): # type: ignore + ret = self.visit_if_(expr) + elif isinstance(expr, Op): + ret = self.visit_op_(expr) + elif isinstance(expr, TupleGetItem): + ret = self.visit_tuple_getitem_(expr) + elif isinstance(expr, PrimValue): + ret = self.visit_prim_value_(expr) + elif isinstance(expr, StringImm): + ret = self.visit_string_imm_(expr) + elif isinstance(expr, DataTypeImm): + ret = self.visit_data_type_imm_(expr) + else: + raise TypeError("Invalid type: {0}".format(type(expr))) + + return ret + + def visit_constant_(self, op: Constant): + raise NotImplementedError() + + def visit_tuple_(self, op: Tuple): + raise NotImplementedError() + + def visit_dataflow_var_(self, op: DataflowVar): + raise NotImplementedError() + + def visit_var_(self, op: Var): + raise NotImplementedError() + + def visit_shape_expr_(self, op: ShapeExpr): + raise NotImplementedError() + + def visit_extern_func_(self, op: ExternFunc): + raise NotImplementedError() + + def visit_global_var_(self, op: GlobalVar): + raise NotImplementedError() + + def visit_function_(self, op: Function): + raise NotImplementedError() + + def visit_call_(self, op: Call): + raise NotImplementedError() + + def visit_seq_expr_(self, op: SeqExpr): + raise NotImplementedError() + + def visit_if_(self, op: If): + raise NotImplementedError() + + def visit_op_(self, op: Op): + raise NotImplementedError() + + def visit_tuple_getitem_(self, op: TupleGetItem): + raise NotImplementedError() + + def visit_prim_value_(self, op: PrimValue): + raise NotImplementedError() + + def visit_string_imm_(self, op: StringImm): + raise NotImplementedError() + + def visit_data_type_imm_(self, op: DataTypeImm): + raise NotImplementedError() + + def visit_var_binding_(self, binding: VarBinding): + raise NotImplementedError() + + def visit_match_cast_(self, binding: MatchCast): + raise NotImplementedError() + + def visit_binding_block_(self, block: BindingBlock): + raise NotImplementedError() + + def visit_dataflow_block_(self, block: DataflowBlock): + raise NotImplementedError() + + def visit_var_def_(self, var: Var): + raise NotImplementedError() + + def visit_dataflow_var_def_(self, var: DataflowVar): + raise NotImplementedError() + + def visit_binding(self, binding: Binding): + if isinstance(binding, MatchCast): + self.visit_match_cast_(binding) + elif isinstance(binding, VarBinding): + self.visit_var_binding_(binding) + else: + raise TypeError("Invalid type: {0}".format(type(binding))) + + def visit_binding_block(self, block: BindingBlock): + if isinstance(block, DataflowBlock): + self.visit_dataflow_block_(block) + elif isinstance(block, BindingBlock): + self.visit_binding_block_(block) + else: + raise TypeError("Invalid type: 
{0}".format(type(block))) + + def visit_var_def(self, var: Var): + if isinstance(var, DataflowVar): + self.visit_dataflow_var_def_(var) + elif isinstance(var, Var): + self.visit_var_def_(var) + else: + raise TypeError("Invalid type: {0}".format(type(var))) + + +@tvm._ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(Object): + """ + A TVM object to support customization of ExprVisitor on the python side. + This is the decorated result returned from visitor decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: visitor, PyExprVisitor + """ + + def __init__( + self, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprVisitor, # type: ignore + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + return _ffi_api.PyExprVisitorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprVisitorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + return _ffi_api.PyExprVisitorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + return _ffi_api.PyExprVisitorVisitVarDef(self, var) # type: ignore + + +class PyExprVisitor: + """ + An abstract ExprVisitor with customized methods on the python-side. 
+ This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods + that users can overwrite("methods"). + + Note: @relax.expr_functor.visitor is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.visitor + def MyExprVisitor(PyExprVisitor): + ... + """ + + _tvm_metadata = { + "cls": _PyExprVisitor, + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> None: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> None: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> None: + """Visit Var. 
+ Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> None: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> None: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> None: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> None: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> None: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> None: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> None: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. 
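The per-node `visit_*_` hooks above all follow the same forwarding pattern, so an end-to-end sketch may be clearer than any single hook. The snippet below is a minimal sketch, assuming a stateful subclass may define its own `__init__` and that `visitor`, `PyExprVisitor`, and the Relax expression classes are importable as shown; it counts how often each operator is called in a function.

import tvm
from tvm import relax
from tvm.relax import Call
from tvm.relax.expr_functor import PyExprVisitor, visitor


@visitor
class OpCounter(PyExprVisitor):
    """Count how often each operator is called inside an expression."""

    def __init__(self, counts: dict):
        super().__init__()
        # The caller owns the dict, so results are easy to read back.
        self.counts = counts

    def visit_call_(self, op: Call) -> None:
        key = op.op.name if isinstance(op.op, tvm.ir.Op) else str(op.op)
        self.counts[key] = self.counts.get(key, 0) + 1
        # Overriding visit_call_ replaces the default traversal of this node,
        # so recurse into the arguments explicitly.
        for arg in op.args:
            self.visit_expr(arg)


# Build a tiny Relax function to visit (illustrative).
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 2), "float32"))
with bb.function("main", [x]):
    y = bb.emit(relax.op.add(x, x))
    bb.emit_func_output(bb.emit(relax.op.multiply(y, y)))

counts = {}
OpCounter(counts).visit_expr(bb.get()["main"])
print(counts)  # -> {'relax.add': 1, 'relax.multiply': 1}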
+ """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> None: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> None: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> None: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> None: + """Visit the Var definition site. 
+ Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> None: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) # type: ignore + + +@tvm._ffi.register_object("expr_functor.PyExprMutator") +class _PyExprMutator(Object): + """ + A TVM object to support customization of ExprMutator on the python side. + This is the decorated result returned from mutator decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: mutator, PyExprmutator + """ + + def __init__( + self, + builder: BlockBuilder = None, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprMutator, # type: ignore + builder, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. 
+ """ + return _ffi_api.PyExprMutatorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprMutatorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + return _ffi_api.PyExprMutatorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitVarDef(self, var) # type: ignore + + +class PyExprMutator: + """ + An abstract ExprMutator with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods that users can + overwrite("methods"), the constructor's parameters("fields") + + Note: @relax.expr_functor.mutator is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.mutator + def MyExprMutator(PyExprMutator): + ... + """ + + _tvm_metadata = { + "cls": _PyExprMutator, + "fields": ["builder_"], + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def __init__(self, mod: Optional[IRModule] = None) -> None: + """Constructor""" + self.builder_ = BlockBuilder(mod) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. 
+ + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result: Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> Expr: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> Expr: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> Expr: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. 
+ + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> Expr: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> Expr: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> Expr: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> Expr: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> Expr: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> Expr: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. 
+ + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> Expr: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> Var: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. 
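For the mutating counterpart of the visitor sketch earlier, a minimal rewriting example built on the hooks above; the add-to-multiply rewrite is purely illustrative, and calling `super().visit_call_` first (so the arguments are rewritten by the default logic) is the assumed idiom.

import tvm
from tvm import relax
from tvm.relax import Call, Expr
from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class AddToMultiply(PyExprMutator):
    """Illustrative rewrite: replace every relax.add call with relax.multiply."""

    def visit_call_(self, call: Call) -> Expr:
        # Rewrite the arguments with the default logic first.
        call = super().visit_call_(call)
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "relax.add":
            return relax.op.multiply(call.args[0], call.args[1])
        return call


# Given an IRModule `mod` containing a Relax function "main":
#   rewritten_main = AddToMultiply(mod).visit_expr(mod["main"])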
+ """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> Span: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + + Returns + ------- + result : Span + The span after transformation. + """ + raise NotImplementedError + + def visit_expr_post_order(self, expr: Expr) -> Expr: + """Post-order rewrite an Expr and normalize. + + Parameters + ---------- + expr : Expr + The Expr to be rewritten. + + Returns + ------- + result : Expr + The Expr after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitExprPostOrder(self._outer(), expr) # type: ignore + + def set_var_remap(self, vid: Id, var: Var) -> None: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var. + var : Var + The new var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorSetVarRemap(self._outer(), vid, var) # type: ignore + + def get_var_remap(self, vid: Id) -> Var: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var + + Returns + ------- + var : Var + The remapped var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorGetVarRemap(self._outer(), vid) # type: ignore + + def visit_with_new_scope(self, expr: Expr) -> Expr: + """Rewrite the expr with a new scope, used in a Function's body and the branches of If. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + var : Var + The expr after visiting. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitWithNewScope(self._outer(), expr) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Look up the value bound to a variable. + Note: For function parameters, this function returns NullOpt. + + Parameters + ---------- + var : Var + The var to be looked up. + + Returns + ------- + var : Var + The value bound to the input var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore + + def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: + """Create a new var with specified shape and type if the original var's shape or type does + not match with the specified ones. + + Parameters + ---------- + var : Var + The var to be updated. + struct_info : StructInfo + The struct info. + + Returns + ------- + var : Var + The var filled with shape and type. 
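The utilities in this class (`lookup_binding` in particular) combine naturally with the rewrite hooks. Below is a hedged sketch that folds `relax.add` over arguments bound to constants; chasing a variable back to its bound value via `lookup_binding` and checking the op name are the only steps beyond the documented API, and the pass is illustrative rather than a complete constant folder.

import tvm
from tvm import relax
from tvm.relax import Call, Constant, Expr
from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class FoldConstAdd(PyExprMutator):
    """Fold relax.add when both arguments are (bound to) constants."""

    def _as_constant(self, expr: Expr):
        # Chase a variable back to its bound value; function params yield None.
        if isinstance(expr, relax.Var):
            bound = self.lookup_binding(expr)
            if bound is not None:
                expr = bound
        return expr if isinstance(expr, Constant) else None

    def visit_call_(self, call: Call) -> Expr:
        call = super().visit_call_(call)
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "relax.add":
            lhs = self._as_constant(call.args[0])
            rhs = self._as_constant(call.args[1])
            if lhs is not None and rhs is not None:
                # Compute the sum eagerly and return it as a new constant.
                return relax.const(lhs.data.numpy() + rhs.data.numpy())
        return call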
+ """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py new file mode 100644 index 000000000000..4baf3195f032 --- /dev/null +++ b/python/tvm/relax/frontend/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Frontends for constructing Relax programs, with the model importers +""" +from .common import detach_params diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py new file mode 100644 index 000000000000..9904324df40e --- /dev/null +++ b/python/tvm/relax/frontend/common.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Commons for Relax frontend.""" +from typing import Dict, List, Tuple + +import tvm + + +def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]: + """Detach the attribute "params" in the functions of the input IRModule as + separate dictionary of params. + + Parameters + ---------- + mod : tvm.IRModule + The IRModule whose functions' "param" attribute is going to be detached. + + Returns + ------- + detached_mod : tvm.IRModule + The IRModule after the detachment. + + params_dict : Dict[str, List[tvm.nd.NDArray]] + The detached params. The dict keys corresponds to the names of the + functions in the input IRModule that have attribute "params". + """ + detached_mod = tvm.IRModule() + params_dict = dict() + for gv, func in mod.functions.items(): + if func.attrs is not None and "params" in func.attrs: + params = list(func.attrs["params"]) + if not all([isinstance(param, tvm.nd.NDArray) for param in params]): + raise ValueError( + 'The value "params" attribute is expected to be a list of NDArray.' 
+ ) + params_dict[gv.name_hint] = params + detached_mod[gv] = func.without_attr("params") + else: + detached_mod[gv] = func + return detached_mod, params_dict diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py new file mode 100644 index 000000000000..55da5a456d6a --- /dev/null +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +PyTorch Frontends for constructing Relax programs, with the model importers +""" +from .fx_translator import from_fx +from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py new file mode 100644 index 000000000000..f48a2cde3c82 --- /dev/null +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, missing-function-docstring, not-callable +# pylint: disable=import-outside-toplevel, unused-argument +# mypy: ignore-errors +"""PyTorch Dynamo backend of Relax.""" +import functools +from typing import Optional + +import tvm +from tvm.relax import build as relax_build + +from .fx_translator import from_fx + + +def device_from_inputs(example_inputs): + for x in example_inputs: + if hasattr(x, "device"): + return x.device + return None + + +def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None): + """A helper function to create a relax backend. + + Parameters + ---------- + pipeline : Optional[tvm.transform.Pass] + The pipeline to be applied to the relax module before sent to build. + + Returns + ------- + backend : Callable[[torch.fx.GraphModule, List[torch.Tensor]], Callable] + The relax dynamo backend. 
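A usage sketch for this backend factory: wiring `relax_dynamo()` into `torch.compile`. This assumes PyTorch 2.x (for `torch.compile`) and a Linux host with an LLVM-enabled TVM build, since the CPU path below picks an llvm target from /proc/cpuinfo; the MLP model is illustrative.

import torch
from tvm.relax.frontend.torch import relax_dynamo


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))


# Compile with the Relax backend; without a `pipeline` argument the default
# Relax optimization pipeline is applied before building.
compiled = torch.compile(MLP(), backend=relax_dynamo())
out = compiled(torch.randn(4, 16))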
+ """ + + def _relax_backend(graph_module, example_inputs): + import torch # type: ignore[import] + + assert isinstance(graph_module, torch.fx.GraphModule) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if isinstance(nd_tensor, tvm.nd.NDArray): + return torch.from_numpy(nd_tensor.numpy()) + elif isinstance(nd_tensor, tvm.ir.Array): + return tuple(to_torch_tensor(x) for x in nd_tensor) + else: + raise ValueError(f"Unsupported type {type(nd_tensor)}") + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor): + return tvm.nd.array(torch_tensor.numpy()) + # Fake Tensor + real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype) + return tvm.nd.array(real_tensor.numpy()) + + device = device_from_inputs(example_inputs) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + mod = from_fx(graph_module, input_info) + + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + # invoke optimization pipeline. + if pipeline is None: + # get default pipeline + seq = tvm.relax.get_pipeline() + elif isinstance(pipeline, str): + # lookup by name + seq = tvm.relax.get_pipeline(pipeline) + else: + seq = pipeline + + mod = mod.with_attr("target", target) + mod = seq(mod) + + ex = relax_build(mod, target=target) + + vm = tvm.relax.VirtualMachine(ex.mod, device=dev) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + vm_args = list() + for arg in args: + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + vm_args.append(to_tvm_tensor(arg)) + outputs = vm["main"](*vm_args) + return to_torch_tensor(outputs) + + return exec_tvm + + return _relax_backend + + +def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule: + """Capture subgraphs of the PyTorch model using torch.compile into an IRModule. + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to be captured. + + params : List[torch.Tensor] + The parameters of the PyTorch model. + + keep_params_as_input : bool + Whether to keep model parameters as input variables of the captured Relax functions. + + Returns + ------- + output : ImporterOutput + The output of translation, including the translated IRModule. + If `keep_params_as_input` is true, the functions in the IRModule have an + attribute "params" that contains the weights of the input model. The + weights can be detached by `relax.frontend.detach_params`. 
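A usage sketch for the capture path, combining `dynamo_capture_subgraphs` with `relax.frontend.detach_params` as described above. It assumes PyTorch 2.x; the single-`Linear` model and the resulting "subgraph_0" name are illustrative.

import torch
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import dynamo_capture_subgraphs

model = torch.nn.Linear(8, 8)
inputs = torch.randn(2, 8)

# Capture the traced subgraph(s) into an IRModule, keeping weights as inputs.
mod = dynamo_capture_subgraphs(model, inputs, keep_params_as_input=True)
mod, params = detach_params(mod)
# `mod` holds one Relax function per captured subgraph (e.g. "subgraph_0");
# `params` maps each such function name to its detached weight NDArrays.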
+ """ + import torch # type: ignore[import] + from torch import fx # type: ignore[import] + from torch import _dynamo as dynamo # type: ignore[import] + + keep_params_as_input = "keep_params_as_input" in kwargs and kwargs["keep_params_as_input"] + kwargs.pop("keep_params_as_input", None) + mod = tvm.IRModule() + + def _capture(graph_module: fx.GraphModule, example_inputs): + assert isinstance(graph_module, torch.fx.GraphModule) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + mod_ = from_fx( + graph_module, + input_info, + keep_params_as_input=keep_params_as_input, + unwrap_unit_return_tuple=True, + ) + mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"] + return graph_module.forward + + dynamo.reset() + compiled_model = torch.compile(model, backend=_capture) + compiled_model(*params, **kwargs) + return mod + + +@functools.lru_cache(None) +def llvm_target(): + if "avx512" in open("/proc/cpuinfo").read(): + return "llvm -mcpu=skylake-avx512" + return "llvm -mcpu=core-avx2" diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py new file mode 100644 index 000000000000..c65e94d6916e --- /dev/null +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -0,0 +1,1312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
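Before the importer internals that follow, a short usage sketch of the public entry point `from_fx`: trace a module with `torch.fx.symbolic_trace`, describe each positional input of `forward()` as a `(shape, dtype)` pair, and translate. The tiny model below is illustrative.

import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x) + 1.0


graph_module = fx.symbolic_trace(TinyModel())
# One (shape, dtype) pair per positional input of forward().
input_info = [((3, 4), "float32")]
mod = from_fx(graph_module, input_info)
mod.show()  # prints the resulting IRModule with a Relax "main" function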
+ +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch FX frontend of Relax.""" +from typing import Callable, Dict, List, Optional, Tuple, Union +from functools import reduce + +import tvm +from tvm import relax + + +class TorchFXImporter: + """An importer from PyTorch FX to Relax.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.node.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Expr] = {} + self.named_modules: Dict[str, torch.Module] = None + self.block_builder: relax.BlockBuilder = None + self.create_convert_map() + + ########## Utilities ########## + @staticmethod + def _fetch_attr(model, target: str): + import torch # type: ignore + + target_atoms = target.split(".") + attr_itr = model + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced non existing target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return attr_itr + + @staticmethod + def _convert_data_type(input_type, env: Optional[Dict] = None): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + if env is not None and input_type in env: + input_type = env[input_type] + + input_type = input_type.lower() if isinstance(input_type, str) else input_type + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + elif input_type in ["int32", "torch.int32", torch.int32]: + return "int32" + elif input_type in ["bool", "torch.bool", torch.bool]: + return "bool" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), dtype) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.node.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + @staticmethod + def _promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, 
lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def _call_binary_op(self, op, lhs, rhs): + lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + ########## Arithmetic ########## + + def _cos(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) + + def _exp(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) + + def _sin(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) + + def _sigmoid(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]])) + + def _sqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.sqrt(arg)) + + def _rsqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + sqrt = self.block_builder.emit(relax.op.sqrt(arg)) + return self.block_builder.emit( + relax.op.divide(relax.const(1, sqrt.struct_info.dtype), sqrt) + ) + + def _round(self, node: fx.node.Node) -> relax.Expr: + if "decimals" in node.kwargs and node.kwargs["decimals"] != 0: + raise ValueError("specifying decimals for round is not supported yet") + arg = self.env[node.args[0]] + return self.block_builder.emit(relax.op.round(arg)) + + def _add(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.add, lhs, rhs) + elif isinstance(lhs, relax.expr.Constant): + return self._call_binary_op( + relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) + ) + elif isinstance(rhs, relax.expr.Constant): + return self._call_binary_op( + relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs + ) + return lhs + rhs + + def _max(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.maximum, lhs, rhs) + + def _floordiv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.floor_divide, lhs, rhs) + return lhs // rhs + + def _mul(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.multiply, lhs, rhs) + return lhs * rhs + + def _pow(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.power, lhs, rhs) + return lhs**rhs + + def _neg(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.negative(x)) + + def _sub(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.subtract, lhs, rhs) + return lhs - rhs + + def _truediv(self, node: fx.node.Node) -> 
relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.divide, lhs, rhs) + return lhs / rhs + + def _clamp(self, node: fx.node.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = node.kwargs["min"] + a_max = node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + ########## Compare ########## + + def _lt(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less, lhs, rhs) + + def _eq(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.equal, lhs, rhs) + + ########## Creation ########## + + def _arange(self, node: fx.node.Node) -> relax.Var: + import torch + import numpy as np + + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] + + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] + + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 + + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) + else: + dtype = "int64" + + return relax.const(np.arange(*start_end_step, dtype=dtype)) + + def _empty(self, node: fx.node.Node) -> relax.Var: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args, dtype)) + + def _inplace_fill(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + + def _tensor(self, node: fx.node.Node) -> relax.Var: + dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None + if isinstance(node.args[0], float): + return relax.const(node.args[0], dtype if dtype is not None else "float32") + elif isinstance(node.args[0], int): + return relax.const(node.args[0], dtype if dtype is not None else "int64") + raise ValueError("torch.tensor with value not a float or int is not accepted") + + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node) -> relax.Var: + x = 
self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + + def _inplace_tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + + mutated = self.block_builder.emit(op(x, k)) + self.env[node.args[0]] = mutated + return mutated + + return convert + + def _new_ones(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + def _ones(self, node: fx.node.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) + + def _full(self, node: fx.node.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + def _mean(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) + + ########## DataType ########## + + def _float(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + + def _to(self, node: fx.node.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = 
TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + ########## Linear Algebra ########## + + def _matmul_impl(self, a: relax.Expr, b: relax.Expr): + return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) + + def _matmul(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + res = self._matmul_impl( + args[0], + args[1], + ) + return res + + def _addmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + return self.block_builder.emit(relax.op.add(x, matmul)) + + def _baddbmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + a = self.env[node.args[1]] + b = self.env[node.args[2]] + alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 + beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(a, b)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + + ########## Manipulation ########## + + def _cat(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.concat(args[0], axis=node.kwargs["dim"])) + + def _expand(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + + def _flatten(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + + def _split(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _chunk(self, node: fx.node.Node) -> 
relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 2: + dim = node.args[2] + else: + dim = 0 + return self.block_builder.emit(relax.op.split(x, chunks, dim)) + + def _transpose(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + def _squeeze(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] + else: + dim = None + return self.block_builder.emit(relax.op.squeeze(x, dim)) + + def _cumsum(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] + else: + dim = None + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") + + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) + + def _index_select(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) + + def _masked_fill(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + + def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) + self.env[node.args[0]] = output + return output + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node): + x = self.env[node.args[0]] + dim = None + keepdims = False + + if len(node.args) > 1: + dim = node.args[1] + if len(node.args) > 2: + keepdims = node.args[2] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + if "keepdim" in node.kwargs: + keepdims = node.kwargs["keepdim"] + if "keepdims" in node.kwargs: + keepdims = node.kwargs["keepdims"] + + return self.block_builder.emit(op(x, dim, keepdims)) + + return convert + + ########## Neural Network ########## + + def _linear(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None if module.bias is None else self.params[module.bias] + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _conv1d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCW", + kernel_layout="OIW", + 
out_dtype="float32", + ) + ) + + if module.bias is None: + return conv1d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + stride = node.args[2] if nargs > 2 else node.kwargs["stride"] + padding = node.args[3] if nargs > 3 else node.kwargs["padding"] + dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] + ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + + stride = kernel if stride is None else stride + + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + from torch import fx + + def _impl(node: fx.node.Node) -> relax.Var: + if is_module: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + else: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + return _impl + + def _softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype)) + running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) + running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + + def _layer_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + import numpy as np # type: ignore + + x = self.env[node.args[0]] + + # functional.layer_norm + if node.target not in 
self.named_modules: + # static or symbolic + normalized_shape = ( + node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]] + ) + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + gamma = self.env[node.kwargs["weight"]] + beta = node.kwargs["bias"] + if beta is None: + shape_tuple = [int(s) for s in normalized_shape.values] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + else: + beta = self.env[beta] + eps = node.kwargs["eps"] + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + module = self.named_modules[node.target] + + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + dim_num = len(module.normalized_shape) + axes = list(range(-dim_num, 0)) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=module.eps, + ) + ) + + def _group_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=module.num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=module.eps, + ) + ) + + def _embedding(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _interpolate(self, node: fx.node.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = ( + node.args[1] + if len(node.args) > 1 + else (node.kwargs["size"] if "size" in node.kwargs else None) + ) + scale_factor = ( + node.args[2] + if len(node.args) > 2 + else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None) + ) + method = ( + node.args[3] + if len(node.args) > 3 + else (node.kwargs["method"] if "method" in node.kwargs else "nearest") + ) + align_corners = ( + node.args[4] + if len(node.args) > 4 + else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None) + ) + recompute_scale_factor = ( + node.args[5] + if len(node.args) > 5 + else ( + node.kwargs["recompute_scale_factor"] + if "recompute_scale_factor" in node.kwargs 
+ else None + ) + ) + antialias = ( + node.args[6] + if len(node.args) > 6 + else (node.kwargs["antialias"] if "antialias" in node.kwargs else False) + ) + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + ########## Others ########## + + def _size(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value + + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Expr): + if node.args[1] == "dtype": + return self.env[node.args[0]].struct_info.dtype + elif node.args[1] == "shape": + return self.shape_of(self.env[node.args[0]]) + return getattr(self.env[node.args[0]], node.args[1]) + + def _getitem(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + begin = [] + end = [] + stride = [] + axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + non_ellipsis_cnt = 0 + for index in node.args[1]: + if isinstance(index, (int, slice)): + non_ellipsis_cnt += 1 + for index in node.args[1]: + if isinstance(index, int): + begin.append(index) + end.append(index + 1) + stride.append(1) + axes.append(i) + i = i + 1 + elif isinstance(index, slice): + begin.append(0 if index.start is None else index.start) + end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(len(axes) + len(expand_dim)) + elif index is Ellipsis: + for _ in range(len(shape) - non_ellipsis_cnt): + begin.append(0) + end.append(shape[i]) + stride.append(1) + axes.append(i) + i += 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + begin.append(0) + end.append(shape[i]) + stride.append(1) + axes.append(i) + i += 1 + sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + elif isinstance(x, relax.Constant): + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) + else: + assert False + + def create_convert_map(self): + from torch import nn + from torch import fx + + self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { + # call_module + nn.Linear: self._linear, + nn.Conv1d: self._conv1d, + nn.Conv2d: 
self._conv2d, + nn.MaxPool2d: self._max_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.Softmax: self._softmax, + nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Flatten: self._flatten, + nn.BatchNorm2d: self._batch_norm_2d, + nn.LayerNorm: self._layer_norm, + nn.GroupNorm: self._group_norm, + nn.Dropout: lambda node: self.env[node.args[0]], + nn.modules.sparse.Embedding: self._embedding, + # call_function and call_method + "cos": self._cos, + "exp": self._exp, + "sin": self._sin, + "add": self._add, + "floordiv": self._floordiv, + "mul": self._mul, + "sub": self._sub, + "pow": self._pow, + "sigmoid": self._sigmoid, + "sqrt": self._sqrt, + "round": self._round, + "lt": self._lt, + "eq": self._eq, + "truediv": self._truediv, + "fill_": self._inplace_fill, + "new_ones": self._new_ones, + "arange": self._arange, + "empty": self._empty, + "tensor": self._tensor, + "tril": self._tril_triu(relax.op.tril), + "triu": self._tril_triu(relax.op.triu), + "tril_": self._inplace_tril_triu(relax.op.tril), + "triu_": self._inplace_tril_triu(relax.op.triu), + "sum": self._sum, + "float": self._float, + "half": self._half, + "type": self._type, + "astype": self._type, + "matmul": self._matmul, + "addmm": self._addmm, + "baddbmm": self._baddbmm, + "bmm": self._matmul, + "cat": self._cat, + "expand": self._expand, + "flatten": self._flatten, + "permute": self._permute, + "reshape": self._reshape, + "split": self._split, + "cumsum": self._cumsum, + "chunk": self._chunk, + "transpose": self._transpose, + "squeeze": self._squeeze, + "unsqueeze": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), + "view": self._reshape, + "argmax": self._argmax_argmin(relax.op.argmax), + "argmin": self._argmax_argmin(relax.op.argmin), + "softmax": self._softmax, + "dropout": lambda node: self.env[node.args[0]], + "clamp": self._clamp, + "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])), + "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), + "interpolate": self._interpolate, + "size": self._size, + "getattr": self._getattr, + "getitem": self._getitem, + "contiguous": lambda node: self.env[node.args[0]], + "to": self._to, + "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "layer_norm": self._layer_norm, + "index_select": self._index_select, + "masked_fill": self._masked_fill, + "ones": self._ones, + "full": self._full, + "masked_fill_": self._inplace_masked_fill, + "mean": self._mean, + "rsqrt": self._rsqrt, + "neg": self._neg, + "max": self._max, + } + + def from_fx( + self, + model, + input_info: List[Tuple[Tuple[int], str]], + keep_params_as_input: bool, + unwrap_unit_return_tuple: bool, + ) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program.""" + from torch import fx + + self.named_modules = dict(model.named_modules()) + + graph: fx.Graph = model.graph + # Create input variables. 
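+        # Each (shape, dtype) entry in input_info becomes a relax.Var named
+        # "inp_{idx}" with a TensorStructInfo; the FX "placeholder" nodes
+        # handled below consume these variables in graph order.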
+ inputs = list() + for idx, (shape, dtype) in enumerate(input_info): + inputs.append( + relax.Var( + f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) + ) + ) + + # Initialize the block builder with a function and a dataflow block. + func_name = "main" + self.block_builder = relax.BlockBuilder() + params = [] + if keep_params_as_input: + func_attrs = {"num_input": len(inputs)} + for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) + self.params[param] = inputs[-1] + params.append(tvm.nd.array(param.data.cpu().numpy())) + else: + func_attrs = None + + with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): + output = None + with self.block_builder.dataflow(): + # Translate model parameters. + for _, param in model.named_parameters(): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + if dtype in ("float32", "float16"): + if not keep_params_as_input: + self.params[param] = relax.const(param.data.cpu().numpy(), dtype) + else: + raise ValueError("Unsupported data type for model parameters: %s" % dtype) + # Translate the model. + for node in graph.nodes: + if node.op == "placeholder": + assert len(inputs) > 0, "Provided inputs is less than actual inputs" + self.env[node] = inputs.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + assert len(args) == 1 + if ( + unwrap_unit_return_tuple + and isinstance(args[0], (tuple, list, relax.Tuple)) + and len(args[0]) == 1 + ): + output = self.block_builder.emit_output(args[0][0]) + else: + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = TorchFXImporter._fetch_attr(model, node.target) + elif node.op == "call_module": + module = self.named_modules[node.target] + assert ( + type(module) in self.convert_map + ), f"Unsupported module type {type(module)}" + self.env[node] = self.convert_map[type(module)](node) + elif node.op == "call_function": + func_name = node.name.rstrip("0123456789_") + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + elif node.op == "call_method": + assert ( + node.target in self.convert_map + ), f"Unsupported function target {node.target}" + self.env[node] = self.convert_map[node.target](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + mod = self.block_builder.get() + if keep_params_as_input: + mod["main"] = mod["main"].with_attr("params", params) + return mod + + +def from_fx( + model, + input_info: List[Tuple[Tuple[int], str]], + *, + keep_params_as_input: bool = False, + unwrap_unit_return_tuple: bool = False, +) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program + + Parameters + ---------- + model : fx.GraphModule + The PyTorch FX GraphModule to convert. + + input_info : List[Tuple[Tuple[int], str]] + A list of shapes and data types of input tensors. + + keep_params_as_input : bool + Whether to keep model parameters as input variables. + + unwrap_unit_return_tuple : bool + A boolean flag indicating if to the return value when it is an unit tuple. + When the return value is not a unit tuple, no unwrap will take place. 
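+        That is, when set to True and the traced function returns a 1-element
+        tuple, the single element is returned directly instead of the tuple.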
+ + Returns + ------- + output : tvm.IRModule + The import result IRModule, with the function "main" containing the + translated logic. + If `keep_params_as_input` is true, the "main" function have an attribute + "params" that contains the weights of the input model. The weights + can be detached by `relax.frontend.detach_params`. + + Examples + -------- + Users can use the FX tracer or dynamo.export() to extract + a fx.GraphModule from a PyTorch model. The following codes show + how to convert a PyTorch model to a Relax program. + + .. code-block:: python + + # Import the importer. + import numpy as np + import torch + from tvm.relax.frontend.torch_fx import from_fx + from torch import _dynamo as dynamo + + # Define the module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) + + def forward(self, input): + return self.linear(input) + + # Instantiate the model and create the input info dict. + torch_model = MyModule() + input_info = [((128, 10), "float32")] + input_tensors = [ + torch.astensor(np.random.randn(*shape).astype(dtype)) + for shape, dtype in input_info + ] + + # Use FX tracer to trace the PyTorch model. + graph_module = fx.symbolic_trace(torch_model) + + # Use the dynamo.export() to export the PyTorch model to FX. + try: + graph_module = dynamo.export(torch_model, *input_tensors) + except: + raise RuntimeError("Failed to export the PyTorch model to FX.") + + # Use the importer to import the PyTorch model to Relax. + mod: tvm.IRModule = from_fx(graph_module, input_info) + + # Print out the imported model. + print(mod.script()) + + Notes + ----- + For a given PyTorch model, to lookup the names of the model inputs in + FX, one can use + + .. code-block:: python + + fx.symbolic_trace(model).graph.print_tabular() + + to print out the tabular representation of the PyTorch module, and then + check the placeholder rows in the beginning of the tabular. + """ + return TorchFXImporter().from_fx( + model, input_info, keep_params_as_input, unwrap_unit_return_tuple + ) diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py new file mode 100644 index 000000000000..fc51a796a7a6 --- /dev/null +++ b/python/tvm/relax/ir/instrument.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common relax pass instrumentation across IR variants.""" +import tvm +from tvm import relax + + +@tvm.instrument.pass_instrument +class WellFormedInstrument: + """An instrument that checks the input/output IRModule of the Pass + is well formed. It will skip specific passes, like Normalize. 
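+
+    A usage sketch (illustrative only; the import path and the pass being
+    applied are assumptions of this example):
+
+    .. code-block:: python
+
+        import tvm
+        from tvm import relax
+        from tvm.relax.ir.instrument import WellFormedInstrument
+
+        # mod is an existing relax IRModule; well-formedness is checked
+        # before and after every pass that is not in the skip list
+        with tvm.transform.PassContext(instruments=[WellFormedInstrument()]):
+            mod = relax.transform.FuseOps()(mod)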
+ """ + + def __init__(self): + self.skip_pass_name = ["Normalize", "ResolveGlobals"] + + def run_before_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) + + def run_after_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py new file mode 100644 index 000000000000..39a645ffea54 --- /dev/null +++ b/python/tvm/relax/op/__init__.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax core operators.""" + +# Operators +from .base import * +from .binary import * +from .create import * +from .datatype import * +from .index import * +from .linear_algebra import * +from .manipulate import * +from .op_attrs import * +from .statistical import * +from .search import * +from .set import * +from .ternary import * +from .unary import * +from . import builtin +from . import image +from . import memory +from . import nn + + +def _register_op_make(): + # pylint: disable=import-outside-toplevel + from . import _ffi_api + from .. import expr + + expr._op_ffi_api = _ffi_api # type: ignore + + +_register_op_make() diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py new file mode 100644 index 000000000000..8dc6a1b4fbb0 --- /dev/null +++ b/python/tvm/relax/op/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op""" +import tvm._ffi + +tvm._ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py new file mode 100644 index 000000000000..d6e8b29b6dca --- /dev/null +++ b/python/tvm/relax/op/base.py @@ -0,0 +1,446 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# pylint: disable=redefined-builtin +"""The base Relax operators.""" +from typing import Union, List, Tuple, Optional + + +import tvm +import tvm.runtime +from tvm.runtime.object import Object + +from . import _ffi_api +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar +from ..expr import Tuple as RxTuple +from ..struct_info import StructInfo, TensorStructInfo +from ...ir import PrimExpr +from ..utils import args_converter + + +py_print = print # pylint: disable=invalid-name + + +def null_value() -> Call: + """Create a call node that represents a null value object. + + Returns + ------- + ret: Call + The created call node. + """ + return _ffi_api.null_value() # type: ignore + + +@args_converter.auto +def call_tir( + gvar: GlobalVar, + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], + tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, +) -> Call: + """ + Call a tir.prim_func and return the output. + + Parameters + ---------- + gvar : GlobalVar + The GlobalVar referring to a tir PrimFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_tir output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used + + Returns + ------- + ret: Call + A call node for the call_tir operator. + """ + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + if isinstance(tir_vars, (list, tuple)): + tir_vars = ShapeExpr(tir_vars) + + return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore + + +@args_converter.auto +def call_dps_packed( + func: Union[str, Expr], + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], +) -> Call: + """ + Call a destination-passing-style packed function and return the output. + + Parameters + ---------- + func : Union[str, Expr] + The destination-passing-style function, can be ExternFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_dps_packed output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + ret: Call + A call node for the call_dps_packed operator. 
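+
+    Examples
+    --------
+    A minimal sketch; the packed function name, shape and dtype below are
+    illustrative only:
+
+    .. code-block:: python
+
+        from tvm import relax
+
+        x = relax.Var("x", relax.TensorStructInfo((16, 16), "float32"))
+        call = relax.op.call_dps_packed(
+            "my_packed_func",                             # hypothetical extern func
+            x,                                            # input arguments
+            relax.TensorStructInfo((16, 16), "float32"),  # out_sinfo
+        )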
+ """ + if isinstance(func, str): + func = ExternFunc(func) + + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore + + +@args_converter.auto +def call_builtin_with_ctx( + func: Union[str, Expr], + args: Expr, + *, + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None, +) -> Call: + """Call a builtin function func. + + Parameters + ---------- + func : Expr + The builtin function to be called. + + args : Expr + The input arguments. + + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] + The struct info arguments to the call node. + + Returns + ------- + ret: Call + The created call node. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if sinfo_args is not None and not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.call_builtin_with_ctx( # type: ignore + func, + args, + sinfo_args, # type: ignore + ) + + +@args_converter.auto +def make_closure( + func: Expr, + args: Expr, +) -> Object: + """ + Create a closure with free variables and return the closure. + + Parameters + ---------- + func : Expr + The closure, can be ExternFunc or PrimFunc. + + args : Expr + The input arguments. + + + Returns + ------- + ret: Object + The VMClosure. + """ + + return _ffi_api.make_closure(func, args) # type: ignore + + +@args_converter.auto +def invoke_closure( + closure: Expr, + args: Expr, + sinfo_args: Union[List[StructInfo], StructInfo], +) -> Object: + """ + Invoke a closure. + + Parameters + ---------- + closure : Expr + The VMClosure object. + + args : Expr + The input arguments. + + type_args: Union[List[StructInfo], StructInfo] + The structure info arguments of the CallNode + + Returns + ------- + ret: Object + The result. + """ + + if not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore + + +def render_object(val: tvm.Object) -> str: + """ + Given a TVM Object, renders it in string form. Used for Relax printing and assertions. + + Parameters + ---------- + val: tvm.Object + An object to render + + Returns + ------- + ret: str + A string representing the value, ideally human-readable + """ + if isinstance(val, tvm.nd.NDArray): + return str(val) + # no pretty-printer by default, so if we don't handle this, + # then we can't look inside tuples + if isinstance(val, tvm.runtime.container.ADT): + # the fields array of an ADT cannot be directly accessed in Python + # so we have to get the length and index into the fields separately + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + # special case: tag = 0 is a tuple + if val.tag == 0: + return f"({fields})" + return f"ADT(tag={val.tag}, fields=[{fields}])" + if isinstance(val, tvm.ir.Array): + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + return f"({fields})" + return str(val) + + +@tvm.register_func("relax.run.shape_to_tensor") +def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> tvm.nd.NDArray: + """ + Takes a ShapeTuple and convert it to NDArray. 
+ + Parameters + ---------- + shape_tuple: tvm.runtime.ShapeTuple + Shape tuple that we want to convert to NDArray at runtime + """ + return tvm.nd.array([int(v) for v in shape_tuple]) + + +@tvm.register_func("relax.run.print") +def relax_print(format_str: str, *format_args: tvm.Object) -> None: + """ + Takes a list of values to print, formats with the given format string. + If the format string is empty, simply prints. + + Call from TVM script like this: + `relax.print(value1, value2, ..., valueN, format=format_str)` + or + `relax.print(value1, value2, ..., valueN) # format_str defaults to ""` + + Parameters + ---------- + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[Object] + The values to print. + """ + val_strs = map(render_object, format_args) + if format_str == "": + py_print(*val_strs) + else: + py_print(format_str.format(*val_strs)) + + +def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr: + """Print op to print the values + + Parameters + ---------- + values : List[Expr] + The values to print. + + format: Union[str, Expr] + The format string or StringImm. + + Returns + ------- + result : Expr + A relax Call, which will print the value during runtime. + """ + if isinstance(format, str): + format = StringImm(format) + + return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member + + +@tvm.register_func("relax.run.assert_op") +def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: + """ + A variadic function. The first value serves as the assertion condition: + If the condition is true, then the operator does nothing. + If the condition is false, then the operator raises an assertion error. + + Arguments after the first value serve as format arguments for the error message; + the last argument must be a format string for the error message (empty by default). + If the format string is the empty string, then the error message will simply include + a comma-separated list of the format arguments. + The condition argument is not included in the format string. + + Parameters + ---------- + condition: tvm.Object + The assertion condition. Must be a boolean scalar. + + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[tvm.Object] + Values used for formatting the string. 
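+
+    For example, when the condition is false and ``format_str`` is
+    ``"expected {}"`` with one format argument, the raised ``AssertionError``
+    carries the format string with that argument rendered via
+    ``render_object``.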
+ """ + if not isinstance(format_str, str): + raise ValueError( + f"The format string argument to assert must be a string, given {type(format_str)})" + ) + + # should be guaranteed by the type system + if not isinstance(condition, tvm.nd.NDArray): + raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") + + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + if not val: + error_message = "Assertion Failed" + if format_args or format_str != "": + rendered = map(render_object, format_args) + if format_str != "": + error_message = format_str.format(*rendered) + else: + error_message = ", ".join(rendered) + raise AssertionError(error_message) + + +def assert_op( + condition: Expr, + format_args: Optional[Union[Expr, List[Expr]]] = None, + format: Union[str, Expr] = "", +) -> Expr: + """ + Create a call to Relax's assert_op operation (`assert` is reserved in Python, + so the name must be distinct). + + Parameters + ---------- + condition: Expr + The assertion condition. + + format_args: Optional[Union[Expr, List[Expr]]] + Format arguments for the error message if the condition fails. + + format: Union[str, Expr] + The format string or StringImm for the error message. + + Returns + ------- + result : Expr + A Call to the Relax assert operation. + """ + if format_args is None: + format_args = [] + if isinstance(format_args, Expr): # type: ignore + format_args = [format_args] + if isinstance(format, str): + format = StringImm(format) + return _ffi_api.assert_op(condition, format_args, format) # type: ignore + + +def shape_of(expr: Expr) -> Expr: + """Get shape of a tensor. + + Parameters + ---------- + expr : Expr + The input Expr. + + Returns + ------- + result : Expr + A relax Call, which gets the shape of the input + """ + return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member + + +def tensor_to_shape(expr: Expr) -> Expr: + """Convert tensor to shape expr. + Parameters + ---------- + expr : Expr + The input Expr + Returns + ------- + result : Expr + A relax Call, which transforms the tensor values to the shape + """ + return _ffi_api.tensor_to_shape(expr) # type: ignore # pylint: disable=no-member + + +def shape_to_tensor(expr: Expr) -> Expr: + """Convert shape to tensor expr. + Parameters + ---------- + expr : Expr + The input Expr + Returns + ------- + result : Expr + A relax Call, which transforms the shape values to the tensor + """ + return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py new file mode 100644 index 000000000000..09a0c30f193e --- /dev/null +++ b/python/tvm/relax/op/binary.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax binary arithmetic and comparison operators.""" +from . import _ffi_api +from ..expr import Expr + +###################### Arithmetic operators ###################### + + +def add(x1: Expr, x2: Expr) -> Expr: + """Addition with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + result : Expr + The computed result. + + Examples + -------- + .. code:: python + + bb = relax.BlockBuilder() + a = relax.Var("a", relax.TensorStructInfo(shape=(2, 3), dtype="float32")) + b = relax.Var("b", relax.TensorStructInfo(shape=(2, 1), dtype="float32")) + c = bb.normalize(relax.op.add(a, b)) # c has TensorStructInfo(shape=(2, 3), dtype="float32") + """ + return _ffi_api.add(x1, x2) # type: ignore + + +def divide(x1: Expr, x2: Expr) -> Expr: + """Division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.divide(x1, x2) # type: ignore + + +def floor_divide(x1: Expr, x2: Expr) -> Expr: + """Floor division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor_divide(x1, x2) # type: ignore + + +def multiply(x1: Expr, x2: Expr) -> Expr: + """Multiplication with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + result : Expr + The computed result. + """ + return _ffi_api.multiply(x1, x2) # type: ignore + + +def power(x1: Expr, x2: Expr): + """Power with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.power(x1, x2) # type: ignore + + +def subtract(x1: Expr, x2: Expr) -> Expr: + """Subtraction with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.subtract(x1, x2) # type: ignore + + +###################### Comparison operators ###################### + + +def equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs == rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.equal(x1, x2) # type: ignore + + +def greater(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs > rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. 
+ + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater(x1, x2) # type: ignore + + +def greater_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs >= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater_equal(x1, x2) # type: ignore + + +def less(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs < rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less(x1, x2) # type: ignore + + +def less_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs <= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less_equal(x1, x2) # type: ignore + + +def not_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs != rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.not_equal(x1, x2) # type: ignore + + +###################### Comparison operators ###################### + + +def maximum(x1: Expr, x2: Expr) -> Expr: + """Element-wise maximum + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.maximum(x1, x2) + + +def minimum(x1: Expr, x2: Expr) -> Expr: + """Element-wise minimum + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.minimum(x1, x2) diff --git a/python/tvm/relax/op/builtin/__init__.py b/python/tvm/relax/op/builtin/__init__.py new file mode 100644 index 000000000000..04837724b165 --- /dev/null +++ b/python/tvm/relax/op/builtin/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=wildcard-import, redefined-builtin +"""Relax builtin operators.""" + +from .builtin import * diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py new file mode 100644 index 000000000000..42fe8cb65234 --- /dev/null +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.builtin""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py new file mode 100644 index 000000000000..9dfb30bc7487 --- /dev/null +++ b/python/tvm/relax/op/builtin/builtin.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""The builtin Relax operators.""" + +from typing import Union +from ...expr import Call, Expr, PrimValue, DataTypeImm +from ...utils import args_converter +from . import _ffi_api + + +@args_converter.auto +def alloc_tensor( + shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int, Expr] +) -> Call: + """Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index. + + Parameters + ---------- + shape : Expr + The shape of the tensor to be allocated. + + dtype : Union[str, Expr] + The datatype of the tensor to be allocated. + + runtime_device_index : Union[int, Expr] + The device index indicating on which device the tensor is to be allocated at runtime. + Index -1 is reserved for the host device. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. 
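+
+    Examples
+    --------
+    A minimal sketch; the shape, dtype and device index are illustrative:
+
+    .. code-block:: python
+
+        from tvm import relax
+        from tvm.relax.op.builtin import alloc_tensor
+
+        alloc = alloc_tensor(relax.ShapeExpr([16, 16]), "float32", 0)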
+ """ + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + if isinstance(runtime_device_index, int): + runtime_device_index = PrimValue(runtime_device_index) + + return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore + + +def stop_lift_params(x: Expr) -> Expr: + """ + An indicator that the consumers of input tensor should not be + lifted to transform_params function + + Parameters + ---------- + x: relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The result tensor that is the same as input tensor + """ + return _ffi_api.stop_lift_params(x) # type: ignore diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py new file mode 100644 index 000000000000..a6643a8633e4 --- /dev/null +++ b/python/tvm/relax/op/create.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Creation operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr, ShapeExpr + +PrimExprLike = Union[int, PrimExpr] + + +def full( + shape: Union[Tuple[PrimExprLike], Expr], + fill_value: Expr, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Fill array with scalar value. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of fill_value. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full(shape, fill_value, dtype) # type: ignore + + +def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor such that + - its shape is the same as the input data tensor's shape, + - its value is filled with the input scalar fill value. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full_like(x, fill_value, dtype) # type: ignore + + +def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all ones, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. 
+ + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.ones(shape, dtype) # type: ignore + + +def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all ones, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.ones_like(x, dtype) # type: ignore + + +def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all zeros, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.zeros(shape, dtype) # type: ignore + + +def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all zeros, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.zeros_like(x, dtype) # type: ignore + + +def tril(x: Expr, k: int = 0) -> Expr: + """Return the lower triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that tril will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal above which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.tril(x, k) # type: ignore + + +def triu(x: Expr, k: int = 0) -> Expr: + """Return the upper triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that triu will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal below which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.triu(x, k) # type: ignore diff --git a/python/tvm/relax/op/datatype.py b/python/tvm/relax/op/datatype.py new file mode 100644 index 000000000000..120487c0bdfc --- /dev/null +++ b/python/tvm/relax/op/datatype.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Datatype operators.""" +from typing import Union + +from tvm import DataType + +from . import _ffi_api +from ..expr import Expr + + +def astype(x: Expr, dtype: Union[str, DataType]) -> Expr: + """Cast input tensor to the given data type. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + dtype: Union[str, DataType] + The target data type + + Returns + ------- + result : relax.Expr + The casted result. + """ + return _ffi_api.astype(x, dtype) # type: ignore + + +def wrap_param(data: Expr, dtype: Union[str, DataType] = "float32") -> Expr: + """Cast input tensor which is model param to data type if the dtype of the input data is not + the same as the given dtype. + Parameters + ---------- + data : relax.Expr + The input data to the operator. + dtype : Union[str, DataType] + The target data type + Returns + ------- + result : relax.Expr + The casted result. + """ + return _ffi_api.wrap_param(data, dtype) # type: ignore diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py new file mode 100644 index 000000000000..f2552ad6ac51 --- /dev/null +++ b/python/tvm/relax/op/image/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Image operators.""" +from .image import * diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py new file mode 100644 index 000000000000..e666203ae7ff --- /dev/null +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py new file mode 100644 index 000000000000..562de5021d53 --- /dev/null +++ b/python/tvm/relax/op/image/image.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Image operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ...expr import Expr, ShapeExpr + + +PrimExprLike = Union[int, PrimExpr] + + +def resize2d( + data: Expr, + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]], + roi: Optional[Union[float, Tuple[float]]] = None, + layout: str = "NCHW", + method: str = "linear", + coordinate_transformation_mode: str = "half_pixel", + rounding_method: str = "round", + cubic_alpha: float = -0.5, + cubic_exclude: int = 0, + extrapolation_value: float = 0.0, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Image resize2d operator. + + This operator takes data as input and does 2D scaling to the given scale factor. + In the default case, where the data_layout is `NCHW` + with data of shape (n, c, h, w) + out will have a shape (n, c, size[0], size[1]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]] + The out size to which the image will be resized. + If specified as a list, it is required to have length either 1 or 2. + If specified as an Expr, it is required to have ndim 2. + + roi: Optional[Union[float, Tuple[float]]] + The region of interest for cropping the input image. Expected to be of + size 4, and format [start_h, start_w, end_h, end_w]. + Only used if coordinate_transformation_mode is tf_crop_and_resize. + + layout : str + Layout of the input. + + method : str + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : str + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. Definitions can be found + in topi/image/resize.py. 
+ [half_pixel, align_corners, asymmetric, pytorch_half_pixel, + tf_half_pixel_for_nn, and tf_crop_and_resize]. + + rounding_method: str + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for bicubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + + extrapolation_value: float + Fill value to use when roi is outside of the image + + out_dtype : Optional[Union[str, DataType]] + The dtype of the output tensor. + It it is not specified, the output will have the same dtype as input if not specified. + + Returns + ------- + result: relax.Expr + The resized result. + """ + if roi is None: + roi = (0.0, 0.0, 0.0, 0.0) # type: ignore + elif isinstance(roi, float): + roi = (roi, roi, roi, roi) # type: ignore + + if isinstance(size, (int, PrimExpr)): + size = (size, size) + if isinstance(size, tuple): + if len(size) == 1: + size = ShapeExpr([size[0], size[0]]) + else: + size = ShapeExpr(size) + + return _ffi_api.resize2d( # type: ignore + data, + size, + roi, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + extrapolation_value, + out_dtype, + ) diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py new file mode 100644 index 000000000000..2a7afa5ba0f9 --- /dev/null +++ b/python/tvm/relax/op/index.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Indexing operators.""" +from typing import List, Optional, Union + +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr + +PrimExprLike = Union[int, PrimExpr] + + +def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr: + """Take elements from a tensor along an axis. + + Parameters + ---------- + x : relax.Expr + The source tensor. + + indices : relax.Expr + The indices of the values to extract. + It is required to be a one-dimensional tensor which has integer dtype. + + axis : Optional[int] + The axis over which to select values. + If it is none, the input tensor is required to be one-dimensional. + + Returns + ------- + ret : relax.Expr + The taken result. + """ + return _ffi_api.take(x, indices, axis) # type: ignore + + +def strided_slice( + x: Expr, + axes: List[int], + begin: List[PrimExprLike], + end: List[PrimExprLike], + strides: Optional[List[PrimExprLike]] = None, +) -> Expr: + """Strided slice of a tensor. + + Parameters + ---------- + x : relax.Expr + The source tensor to be sliced. + + axes : List[int] + Axes along which slicing is applied. + + begin : List[PrimExprLike] + The indices to begin with in the slicing, inclusive. + + end : List[PrimExprLike] + The indices indicating end of the slice, exclusive. 
+
+    strides : Optional[List[PrimExprLike]]
+        Specifies the stride values; a stride can be negative, in which case
+        the input tensor will be reversed along that particular axis.
+        If not specified, it defaults to a list of ones of the same length as `axes`.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The sliced result.
+
+    Note
+    ----
+    strided_slice requires the input `begin`, `end` and `strides` to have the
+    same length as `axes`.
+    """
+    return _ffi_api.strided_slice(x, axes, begin, end, strides)  # type: ignore
diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py
new file mode 100644
index 000000000000..940861a97227
--- /dev/null
+++ b/python/tvm/relax/op/linear_algebra.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Relax linear algebra operators"""
+from typing import Optional, Union
+
+from tvm import DataType
+
+from ..expr import Expr
+from . import _ffi_api
+from .manipulate import permute_dims
+
+
+def matmul(x1: Expr, x2: Expr, out_dtype: Optional[Union[str, DataType]] = None) -> Expr:
+    """General matrix multiplication of two tensors, with broadcasting on batched dimensions.
+
+    The semantics and output shape deduction rule are specified at
+    https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+
+    x2 : relax.Expr
+        The second input tensor.
+
+    out_dtype: Optional[Union[str, DataType]]
+        The data type of the matmul result.
+        When it is not specified, the output dtype will be the same as the input dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.matmul(x1, x2, out_dtype)  # type: ignore
+
+
+def linear(
+    data: Expr,
+    weight: Expr,
+    bias: Optional[Expr] = None,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    """Applies a linear transformation to the incoming data: y = xA^T + b
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    weight : relax.Expr
+        The weight tensor.
+
+    bias : Optional[Expr]
+        The bias tensor.
+
+    out_dtype: Optional[Union[str, DataType]]
+        The data type of the matmul result.
+        When it is not specified, the output dtype will be the same as the input dtype.
+
+    Notes
+    -----
+    Relax does not regard linear as a primitive op; it is composed from the
+    permute_dims (transpose), matmul and add ops.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+
+    # Since weight can be 1D or 2D, we use `axes=None` to support both cases.
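+    # Illustrative shapes (assumed, not checked here): with data of shape
+    # (n, in_features) and a 2D weight of shape (out_features, in_features),
+    # permute_dims gives (in_features, out_features), the matmul result is
+    # (n, out_features), and the optional bias is broadcast-added to it.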
+ x = matmul(data, permute_dims(weight, axes=None), out_dtype=out_dtype) + return x + bias if bias is not None else x diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py new file mode 100644 index 000000000000..e9c3ce79d745 --- /dev/null +++ b/python/tvm/relax/op/manipulate.py @@ -0,0 +1,441 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Manipulation operators.""" +from typing import List, Optional, Tuple, Union, Callable + +from tvm import DataType +from tvm.ir.expr import PrimExpr +from tvm.tir import IntImm, FloatImm, IndexMap + +from . import _ffi_api +from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple + + +PrimExprLike = Union[int, PrimExpr] + + +def broadcast_to(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Broadcasts a tensor to a specified shape. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + shape : Union[Tuple[PrimExprLike], Expr] + The target shape. + + Returns + ------- + result : relax.Expr + The broadcasted tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.broadcast_to(x, shape) # type: ignore + + +def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: + """Concatenate the input tensors along the given axis. + + Parameters + ---------- + tensors : Union[relax.Expr, List[relax.Expr]] + An Expr in Tuple type, containing the tensors to be concatenated, + or a list of Tensors. + + axis : Optional[int] + The axis along which the tensors are concatenated. + If `axis` is `None`, the input tensor is required to be flattened before concatenation. + + Returns + ------- + result: relax.Expr + The concatenated tensor. + """ + if isinstance(tensors, (list, tuple)): + tensors = RxTuple(tensors) + return _ffi_api.concat(tensors, axis) # type: ignore + + +def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: + """Insert new axes at the positions given by `axis`. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axis : Union[int, List[int]] + The axes at which the input array are expanded. + All values are required to lie in range `[-data.ndim - 1, data.ndim]`, with the convention + of negative indexing. + + Returns + ------- + result : relax.Expr + The transformed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.expand_dims(x, axis) # type: ignore + + +def flatten(x: Expr) -> Expr: + """Flatten all the tensor dimensions into one. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + Returns + ------- + result : relax.Expr + The flattened result. 
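+
+    Examples
+    --------
+    The shape below is an illustrative assumption:
+
+    .. code-block:: python
+
+        x.shape = (2, 3, 4), result.shape = (24,)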
+ """ + return _ffi_api.flatten(x) # type: ignore + + +def layout_transform( + x: Expr, + index_map: Union[Callable, IndexMap], + pad_value: Optional[Union[int, float, PrimValue]] = None, +): + """Modifies the layout of a tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor to the operator. + + index_map : Union[Callable, IndexMap] + The transformation to apply. + + pad_value : Optional[Union[int, float, PrimValue]] + The value used for padding if the transformation results in implicit padding. + If not specified, any value can be used. + + Returns + ------- + result : relax.Expr + The transformed tensor. + """ + if callable(index_map): + index_map = IndexMap.from_func(index_map) + x_dtype = x.checked_type.dtype + + # Explicitly convert python int/float pad_value to the x's type. If the default behavior + # is applied, it would be converted to int32/float32, which may not match the x's type. + if pad_value is None: + pass + elif not isinstance(pad_value, PrimValue): + if "int" in x_dtype and isinstance(pad_value, int): + pad_value = IntImm(x_dtype, pad_value) + elif "float" in x_dtype and (isinstance(pad_value, (int, float))): + pad_value = FloatImm(x_dtype, float(pad_value)) + pad_value = PrimValue(pad_value) + return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore + + +def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr: + """Permutes the dimensions of an array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axes : Optional[List[int]] + The target axes order, reverse order if not specified. + + Returns + ------- + result : relax.Expr + The transposed result. + """ + return _ffi_api.permute_dims(x, axes) # type: ignore + + +def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Reshape the input array. + + ``-1`` infers the dimension of the output shape by using the remainder of + the input dimensions keeping the size of the new array same as that of the input array. + At most one dimension of shape can be -1. + + .. code-block:: python + + x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4) + x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8) + x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,) + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + shape : Union[Tuple[PrimExprLike], Expr] + The new shape. Should be compatible with the original shape. + + Returns + ------- + result : relax.Expr + The reshaped result. + + Note + ---- + The ``-1`` inference is only performed at compile-time. + That is to say, in any case the dimension length of ``-1`` cannot be inferred in + compile-time, an error will be thrown. + """ + return _ffi_api.reshape(x, shape) # type: ignore + + +def split( + x: Expr, + indices_or_sections: Union[int, List[PrimExprLike]], + axis: int = 0, +) -> Expr: + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis (if possible). Last section will be smaller if the tensor + size along the given dimension is not divisible by the integer. + + If indices_or_sections is a tuple of mixture of int or PrimExpr, + the entries indicate the indices where along axis the array is split. + + Parameters + ---------- + x : relax.Expr + The tensor to be split. + + indices_or_sections : Union[int, List[PrimExprLike]] + Indices or sections to split into. Accepts an int or a list. 
+ + axis : int + The axis over which to split. + + Returns + ------- + ret : relax.Expr + The computed result. + """ + if isinstance(indices_or_sections, int): + indices_or_sections = IntImm("int64", indices_or_sections) + return _ffi_api.split(x, indices_or_sections, axis) # type: ignore + + +def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr: + """Squeeze axes in the array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axis : Optional[Union[int, List[int]] + The set of axes to remove. + If axis = None, remove all axis of dimensions 1. + If any specified axis has dimension that does not equal 1, it is an error. + + Returns + ------- + result : relax.Expr + The squeezed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.squeeze(x, axis) # type: ignore + + +def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr: + """Return a summation of data to the shape of collapse_target. + + For details, please see relax.op.collapse_sum_to. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + collapse_target : relax.Expr + The tensor whose shape is the shape to collapse to. + + Returns + ------- + result : relax.Expr + The result tensor after summation. + """ + return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore + + +def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Return a summation of data to the given shape. + + collapse_sum_to is intended as the backward operator of tvm.relax.op.broadcast_to and + other broadcast operators in the automatic differentiation process. + + We expect that data is the result of broadcasting some tensor of the given shape in some + broadcast operation. Thus the given `shape` and `data.shape` must follow broadcast rules. + + During computation, all axes of `data.shape` and `shape` are checked from right to left. + For an axis, if it follows these rules, `data` will be summed over this axis: + - the axis exists in `data.shape` but not in `shape`, or + - the axis exists in `data.shape` and equals to 1 in `shape`. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + shape : Union[Tuple[PrimExprLike], relax.Expr] + The shape to collapse to. + + Returns + ------- + result : relax.Expr + The result tensor of the given shape after summation. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.collapse_sum_to(data, shape) # type: ignore + + +def repeat(data: Expr, repeats: int, axis: Optional[int] = None) -> Expr: + """Repeats elements of an array. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + repeats : int + The number of repetitions. + + axis: Optional[int] + The axis along which to repeat values. The negative numbers are interpreted + counting from the backward. By default, use the flattened input array, and + return a flat output array. + + Returns + ------- + ret : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + x = R.const([[1, 2], [3, 4]]) + lv1 = R.repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4] + lv2 = R.repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.], + # [3., 3., 4., 4.]] + """ + return _ffi_api.repeat(data, repeats, axis) # type: ignore + + +def tile(data: Expr, repeats: Union[int, Tuple[int], List[int]]) -> Expr: + """Construct an array by repeating data the number of times given by repeats. 
+ + If repeats has length l, and data has dimension d, the result will have dimension of max(l, d). + + If d < l, data is promoted to be l-dimensional by prepending new axes. So a shape (3,) Tensor is + promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not + the desired behavior, promote data to d-dimensions manually before calling this function. + + If d > l, reps is promoted to length d by pre-pending 1's to it. Thus for a data of shape + (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + repeats : Union[int, Tuple[int], List[int]] + The number of repetitions of data along each axis. + + Returns + ------- + ret : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + x = R.const([[1, 2], [3, 4]]) + lv1 = R.tile(x, reps=(2, 3)) # lv1 = [[1., 2., 1., 2., 1., 2.], + # [3., 4., 3., 4., 3., 4.], + # [1., 2., 1., 2., 1., 2.], + # [3., 4., 3., 4., 3., 4.]] + lv2 = R.tile(x, reps=2) # lv2 = [[1., 2., 1., 2.], + # [3., 4., 3., 4.]] + """ + if isinstance(repeats, int): + repeats = [repeats] + return _ffi_api.tile(data, repeats) # type: ignore + + +def cumsum(data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None): + """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along + a given axis. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + axis : Optional[int] + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + dtype : Optional[Union[str, DataType]] + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + Returns + ------- + result : relax.Expr + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + + Examples + -------- + .. code-block:: python + + a = [[1, 2, 3], [4, 5, 6]] + + cumsum(a) # if axis is not provided, cumsum is done over the flattened input. + -> [ 1, 3, 6, 10, 15, 21] + + cumsum(a, dtype="float32") + -> [ 1., 3., 6., 10., 15., 21.] + + cumsum(a, axis=0) # sum over rows for each of the 3 columns + -> [[1, 2, 3], + [5, 7, 9]] + + cumsum(a, axis=1) + -> [[ 1, 3, 6], + [ 4, 9, 15]] + + a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array + cumsum(a, dtype=int32) # dtype should be provided to get the expected results + -> [1, 1, 2, 2, 3, 4, 4] + """ + return _ffi_api.cumsum(data, axis, dtype) # type: ignore diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py new file mode 100644 index 000000000000..e039590251fc --- /dev/null +++ b/python/tvm/relax/op/memory/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax memory primitives.""" + +from .memory import * diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py new file mode 100644 index 000000000000..475de481b22e --- /dev/null +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.memory""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py new file mode 100644 index 000000000000..7b84ffc48bb6 --- /dev/null +++ b/python/tvm/relax/op/memory/memory.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""Relax memory primitives.""" + +from typing import Union +from . import _ffi_api +from ...expr import Expr, Call, PrimValue, DataTypeImm, StringImm +from ...utils import args_converter + + +@args_converter.auto +def alloc_storage( + size: Expr, + virtual_device_index: Union[int, Expr], + storage_scope: Union[str, Expr], + dtype: Union[str, Expr], +) -> Call: + """Construct a Call to allocate a storage with specific size, virtual_device_index, + storage_scope and dtype. + + Parameters + ---------- + size : Expr + The size of the storage to be allocated. + + virtual_device_index : Union[int, Expr] + The virtual device index indicating on which device the storage is to be allocated. + Index -1 is reserved for the host device. + + storage_scope : Union[str, Expr] + The storage scope to allocate the storage to. + + dtype : Union[str, Expr] + The datatype of the storage to be allocated. 
+ + Returns + ------- + result : Call + A relax Call, which gets the allocated storage. + """ + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + if isinstance(storage_scope, str): + storage_scope = StringImm(storage_scope) + if isinstance(virtual_device_index, int): + virtual_device_index = PrimValue(virtual_device_index) + return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore + + +@args_converter.auto +def alloc_tensor( + storage: Expr, offset: Union[int, Expr], shape: Expr, dtype: Union[str, Expr] +) -> Call: + """Construct a Call to allocate a tensor on a certain storage starting from the given offset. + + Parameters + ---------- + storage : Expr + The storage to allocate the tensor to. + + offset : Union[int, Expr] + The storage offset to allocate the tensor. + + shape : Expr + The shape of the tensor to be allocated. + + dtype : Union[str, Expr] + The datatype of the tensor to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + if isinstance(offset, int): + offset = PrimValue(offset) + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + return _ffi_api.alloc_tensor(storage, offset, shape, dtype) # type: ignore + + +@args_converter.auto +def kill_storage(storage: Expr) -> Call: + """Construct a Call to kill a storage. + + Parameters + ---------- + storage : Expr + The storage to be killed. + + Returns + ------- + result : Call + A relax Call to kill a storage. + """ + return _ffi_api.kill_storage(storage) # type: ignore + + +@args_converter.auto +def kill_tensor(tensor: Expr) -> Call: + """Construct a Call to kill a tensor. + + Parameters + ---------- + tensor : Expr + The tensor to be killed. + + Returns + ------- + result : Call + A relax Call to kill a tensor. + """ + return _ffi_api.kill_tensor(tensor) # type: ignore diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py new file mode 100644 index 000000000000..af2aa106bca7 --- /dev/null +++ b/python/tvm/relax/op/nn/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from .nn import * diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py new file mode 100644 index 000000000000..1785345ac1b1 --- /dev/null +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py new file mode 100644 index 000000000000..02468637e0f9 --- /dev/null +++ b/python/tvm/relax/op/nn/nn.py @@ -0,0 +1,962 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Neural Network (NN) operators""" +from typing import List, Optional, Tuple, Union + +from tvm import DataType +from tvm.tir import FloatImm + +from . import _ffi_api +from ...expr import Expr + + +def conv1d( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + data_layout: str = "NCW", + kernel_layout: str = "OIW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""1D convolution. + + This operator takes the weight as the 1D convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCW` + and kernel_layout is `OIW`, conv1d takes in + a data Tensor with shape `(batch_size, in_channels, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_w)`, + where `kernel_w` is the length of the `W` kernel dimension, + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, x] = \sum_{dx, k} + \mbox{data}[b, k, \mbox{strides} * x + dx] * + \mbox{weight}[c, k, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCW` for data and `OIW` for weight), perform the computation, + then convert to the out_layout. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int]] + The strides of convolution. It is required to have length 1. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1 or 2. 
+ + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length 1. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv1d. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(strides, int): + strides = (strides,) + if isinstance(dilation, int): + dilation = (dilation,) + if isinstance(padding, int): + padding = (padding, padding) + + return _ffi_api.conv1d( # type: ignore + data, + weight, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + +def conv2d( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + groups: int = 1, + data_layout: str = "NCHW", + kernel_layout: str = "OIHW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""2D convolution. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCHW` + and kernel_layout is `OIHW`, conv2d takes in + a data Tensor with shape `(batch_size, in_channels, height, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_h, kernel_w)`, + where `kernel_h` and `kernel_w` is the lengths of the `H` and `W` kernel dimensions, + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, y, x] = \sum_{dy, dx, k} + \mbox{data}[b, k, \mbox{strides}[0] * y + dy, \mbox{strides}[1] * x + dx] * + \mbox{weight}[c, k, dy, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCHW` for data and `OIHW` for weight), perform the computation, + then convert to the out_layout. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int, int]] + The strides of convolution. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length either 1 or 2. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : relax.Expr + The computed result. 
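+
+    Examples
+    --------
+    An illustrative shape sketch; the variables and shapes are assumptions:
+
+    .. code-block:: python
+
+        # data: (1, 3, 32, 32) in NCHW, weight: (8, 3, 3, 3) in OIHW
+        y = R.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1))
+        # output shape: (1, 8, 32, 32), since (32 + 2 * 1 - 3) // 1 + 1 == 32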
+ """ + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.conv2d( # type: ignore + data, + weight, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + +def conv2d_transpose( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + output_padding: Union[int, Tuple[int, int]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + groups: int = 1, + data_layout: str = "NCHW", + kernel_layout: str = "IOHW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""Two dimensional transposed convolution operator. + + This operator is intended to be the gradient operator of conv2d. That means, if + + `out = conv2d(data, weight, strides, padding, dilation)`, + + The gradient w.r.t. data can be calculated as follows: + + `data_grad = conv2d_transpose(out_grad, weight, strides, padding, output_padding, dilation)`, + + where `output_padding` is a parameter used to determine the output shape. + + The output shape can be explained in the simple case when `data_layout == "NCHW"` and + `kernel_layout == "IOHW"`. Suppose `data` has shape `(N, in_channel, in_h, in_w)`, `weight` has + shape `(in_channel, out_channel, weight_h, weight_w)`, we need to assure that + `in_channel % groups == 0`. The shape of the output will be + `(N, out_channel * groups, out_h, out_w)`, where + + - `out_h = ((in_h - 1) * strides[0] + weight_h - 2 * padding[0] + output_padding[0])` + - `out_w = ((in_w - 1) * strides[1] + weight_w - 2 * padding[1] + output_padding[1])` + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int, int]] + The strides of convolution. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1, 2 or 4. + + output_padding : Union[int, Tuple[int, ...]], optional + Used to disambiguate the output shape. + + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length either 1 or 2. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + # TODO: symbolic shape is not fully supported now + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + if isinstance(output_padding, int): + output_padding = (output_padding, output_padding) + + return _ffi_api.conv2d_transpose( # type: ignore + data, + weight, + strides, + padding, + output_padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + +def max_pool2d( + data: Expr, + pool_size: Union[int, Tuple[int, int]] = (1, 1), + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + ceil_mode: bool = False, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D maximum pooling operator. + + This operator takes data as input and does 2D max value calculation + with in pool_size sized window by striding defined by stride. + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, h, w) and pool_size (kh, kw) + + .. math:: + + \mbox{out}(b, c, y, x) = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1} + \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n) + + Padding is applied to data before the computation. + ceil_mode is used to take ceil or floor while computing out shape. + This operator accepts data layout specification. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + pool_size : Union[int, Tuple[int, int]] + The size of window for pooling. It is required to have length either 1 or 2. + + strides : Union[int, Tuple[int, int]] + The strides of pooling. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding for pooling. It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + The dilation of pooling. It is required to have length either 1 or 2. + + ceil_mode : bool + A boolean indicating if use ceil or floor to compute the output shape. + By using ceil, every element in the input tensor will be covered by a sliding window. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : Expr + The computed result. + """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.max_pool2d( # type: ignore + data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout + ) + + +def avg_pool2d( + data: Expr, + pool_size: Union[int, Tuple[int, int]] = (1, 1), + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + ceil_mode: bool = False, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D average pooling operator. + + This operator takes data as input and does 2D avarage value calculation + with in pool_size sized window by striding defined by stride. 
+ + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, h, w) and pool_size (kh, kw) + + .. math:: + + \mbox{out}(b, c, y, x) = \frac{1}{kh * kw} \sum_{m=0, \ldots, kh-1} + \sum_{n=0, \ldots, kw-1} + \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n) + + Padding is applied to data before the computation. + ceil_mode is used to take ceil or floor while computing out shape. + This operator accepts data layout specification. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + pool_size : Union[int, Tuple[int, int]] + The size of window for pooling. It is required to have length either 1 or 2. + + strides : Union[int, Tuple[int, int]] + The strides of pooling. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding for pooling. It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + The dilation of pooling. It is required to have length either 1 or 2. + + ceil_mode : bool + A boolean indicating if use ceil or floor to compute the output shape. + By using ceil, every element in the input tensor will be covered by a sliding window. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : Expr + The computed result. + """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.avg_pool2d( # type: ignore + data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout + ) + + +def adaptive_avg_pool2d( + data: Expr, + output_size: Optional[Union[int, Tuple[int, int]]] = None, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D adaptive average pooling operator. This operator is experimental. + + This operator takes data as input and does 2D average value calculation + across each window represented by WxH. + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with shape + (batch_size, in_channels, output_height, output_width). + + The pooling kernel and stride sizes are automatically chosen for + desired output sizes. + + For output_size: + If this argument is not provided, input height and width will be used + as output height and width. + + If a single integer is provided for output_size, the output size is + (N x C x output_size x output_size) for any input (NCHW). + + If a tuple of integers (height, width) are provided for output_size, + the output size is (N x C x height x width) for any input (NCHW). + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + output_size : Optional[Union[int, Tuple[int, int]]] + Output height and width. + If not specified, it will be the same as the input height and width. + If specified, it is required to have length either 1 or 2. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : relax.Expr + The computed result. 
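+
+    Examples
+    --------
+    Illustrative shapes only (assumed, not validated here):
+
+    .. code-block:: python
+
+        # data: (1, 64, 14, 14) in NCHW, pooled to a fixed spatial size
+        y = R.nn.adaptive_avg_pool2d(x, output_size=(7, 7))  # (1, 64, 7, 7)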
+ """ + if isinstance(output_size, int): + output_size = (output_size, output_size) + return _ffi_api.adaptive_avg_pool2d(data, output_size, layout, out_layout) # type: ignore + + +def relu(data: Expr) -> Expr: + """Rectified linear unit. + + .. math:: + text{ReLU}(x) = max(x, 0) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.relu(data) # type: ignore + + +def gelu(data: Expr) -> Expr: + """Gaussian Error Linear Units function + + .. math:: + text{GeLU}(x) = 0.5 * x * (1 + erf(x * 0.5**0.5)) + + where :math:`erf` is the Gauss Error function. + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.gelu(data) # type: ignore + + +def silu(data: Expr) -> Expr: + """Sigmoid Linear Unit function + + .. math:: + text{SiLU}(x) = x * sigmoid(x) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.silu(data) # type: ignore + + +def softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes softmax. + + .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)} + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.softmax(data, axis) # type: ignore + + +def log_softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes log softmax. + + .. math:: + + \text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right) + + .. note:: + This operator can be optimized away for inference. + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing log softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.log_softmax(data, axis) # type: ignore + + +def batch_norm( + data: Expr, + gamma: Expr, + beta: Expr, + moving_mean: Expr, + moving_var: Expr, + axis: int, + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Batch normalization layer (Ioffe and Szegedy, 2014). + Normalizes the input at each batch, i.e. applies a transformation + that maintains the mean activation close to 0 and the activation + standard deviation close to 1. + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} + * gamma[i] + beta[i] + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. 
+ + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated by + + .. code:: python + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. + Specifying -1 sets the channel axis to be the last item in the input shape. + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + moving_mean : relax.Expr + Running mean of input. + + moving_var : relax.Expr + Running variance of input. + + axis : int + The axis along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.batch_norm( # type: ignore + data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale + ) + + +def layer_norm( + data: Expr, + gamma: Expr, + beta: Expr, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Layer normalization (Lei Ba and et al., 2016). + Applies layer normalization to the n-dimensional input array. + This operator takes an n-dimensional input array and normalizes + the input using the given axis: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along the channel dimension. + + Assume the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + Input to which layer_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore + + +def group_norm( + data: Expr, + gamma: Expr, + beta: Expr, + num_groups: int, + channel_axis: int, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Group normalization (Yuxin Wu and et al., 2016). + Applies group normalization to the n-dimensional input array. + This operator takes an n-dimensional input array. First separate the input array + into groups along the channel axis. Then apply layer normalization to each group. + + Parameters + ---------- + data : relax.Expr + Input to which group_norm will be applied. 
+ + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + num_groups : int + Number of groups to separate the channels into. + + channel_axis : int + The index of the channel axis in the input data. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied (excluding the group axis) + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.group_norm( # type: ignore + data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale + ) + + +def dropout(data: Expr, rate: float = 0.5) -> Expr: + """Applies the dropout operation to the input tensor. + + During training, each element of the input is set to zero with + probability ``p``. The whole array is scaled by ``1/(1-p)`` + to keep the expected sum of the input unchanged. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + rate : float + The probability for an element to be reset to 0. + + Returns + ------- + result : relax.Expr + The result of dropout, which is a tuple of two tensors. + The first one is the original tensor and the second one is a + mask tensor (1.0 where element not dropped, 0.0 where dropped) + """ + return _ffi_api.dropout(data, rate) # type: ignore + + +def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy with logits between the predictions and labels. + + The shape of predictions and labels must be the same. And when ndim >= 2, + the first dimension is regarded as the batch_size N. In this case the + computed result will divide by N to perform a mean reduction. + + .. math:: + + \text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore + + +def attention( + query: Expr, + key: Expr, + value: Expr, + bias: Optional[Expr] = None, + scale: Optional[FloatImm] = None, +) -> Expr: + r"""Computes fused multi head attention. + + All input tensors are of 4-D tensors with BSNH layout. + + .. math:: + FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V + + .. note:: + The input tensor is required to have float16 dtype + + Parameters + ---------- + query: relax.Expr + The input query to the operator. The layout of the input query should be + (batch_size, seq_len, num_head, head_dim). + + key: relax.Expr + The input key to the operator. The layout of the input key should be + (batch_size, seq_len_kv, num_head, head_dim). + + value: relax.Expr + The input value to the operator. The layout of the input value should be + (batch_size, seq_len_kv, num_head, head_dim_v). + + bias: Optional[Expr] + The optional attention bias to the operator. The layout of the attention bias should be + (batch_size, num_head, seq_len, seq_len_kv), + (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv). + + scale: Optional[FloatImm] + The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim). + + Returns + ------- + result : relax.Expr + The computed result. 
The layout of the output should be + (batch_size, seq_len, num_head, head_dim_v). + """ + return _ffi_api.attention(query, key, value, bias, scale) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py new file mode 100644 index 000000000000..2d0fdd14b34b --- /dev/null +++ b/python/tvm/relax/op/op_attrs.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The attributes node used for Relax operators""" +from tvm.ir import Attrs +import tvm._ffi + + +@tvm._ffi.register_object("relax.attrs.InitAttrs") +class InitAttrs(Attrs): + """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" + + +@tvm._ffi.register_object("relax.attrs.TriluAttrs") +class TriluAttrs(Attrs): + """Attributes used in tril and triu operator""" + + +@tvm._ffi.register_object("relax.attrs.AstypeAttrs") +class AstypeAttrs(Attrs): + """Attributes used in astype operator""" + + +@tvm._ffi.register_object("relax.attrs.TakeAttrs") +class TakeAttrs(Attrs): + """Attributes used in take operator""" + + +@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") +class StridedSliceAttrs(Attrs): + """Attributes used in strided_slice operator""" + + +@tvm._ffi.register_object("relax.attrs.MatmulAttrs") +class MatmulAttrs(Attrs): + """Attributes for matmul operator""" + + +@tvm._ffi.register_object("relax.attrs.Conv2DAttrs") +class Conv2DAttrs(Attrs): + """Attributes for nn.conv2d""" + + +@tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs") +class Conv2DTransposeAttrs(Attrs): + """Attributes for nn.conv2d_transpose""" + + +@tvm._ffi.register_object("relax.attrs.Pool2DAttrs") +class Pool2DAttrs(Attrs): + """Attributes for nn.max_pool2d""" + + +@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +class AdaptivePool2DAttrs(Attrs): + """Attributes for 2d adaptive pool operator""" + + +@tvm._ffi.register_object("relax.attrs.SoftmaxAttrs") +class SoftmaxAttrs(Attrs): + """Attributes for nn.softmax""" + + +@tvm._ffi.register_object("relax.attrs.BatchNormAttrs") +class BatchNormAttrs(Attrs): + """Attributes used in batch_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.LayerNormAttrs") +class LayerNormAttrs(Attrs): + """Attributes used in layer_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.DropoutAttrs") +class DropoutAttrs(Attrs): + """Attributes for dropout operator""" + + +@tvm._ffi.register_object("relax.attrs.StatisticalAttrs") +class StatisticalAttrs(Attrs): + """Attributes used in statistical operator""" + + +@tvm._ffi.register_object("relax.attrs.ConcatAttrs") +class ConcatAttrs(Attrs): + """Attributes for concat operator""" + + +@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs") +class ExpandDimsAttrs(Attrs): + """Attributes for expand_dims operator""" + + 
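+# Illustrative note (an editor's sketch, not part of the registry above): these
+# attribute nodes surface as the ``attrs`` field of a ``relax.Call``. For instance,
+# when ``call.op`` is the ``relax.expand_dims`` operator, ``call.attrs`` is an
+# ``ExpandDimsAttrs`` whose fields (such as ``axis``) can be read like ordinary
+# object attributes.
+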
+@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs") +class PermuteDimsAttrs(Attrs): + """Attributes for permute_dims operator""" + + +@tvm._ffi.register_object("relax.attrs.SplitAttrs") +class SplitAttrs(Attrs): + """Attributes used in split operator""" + + +@tvm._ffi.register_object("relax.attrs.SqueezeAttrs") +class SqueezeAttrs(Attrs): + """Attributes for squeeze operator""" + + +@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") +class LayoutTransformAttrs(Attrs): + """Attributes used in layout_transform operator""" + + +@tvm._ffi.register_object("relax.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes used in image resize2d operator""" + + +@tvm._ffi.register_object("relax.attrs.ArgmaxArgminAttrs") +class ArgmaxArgminAttrs(Attrs): + """Attributes for argmax/argmin operator""" + + +@tvm._ffi.register_object("relax.attrs.RepeatAttrs") +class RepeatAttrs(Attrs): + """Attributes for repeat operator""" + + +@tvm._ffi.register_object("relax.attrs.TileAttrs") +class TileAttrs(Attrs): + """Attributes for tile operator""" diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py new file mode 100644 index 000000000000..b097d78234d5 --- /dev/null +++ b/python/tvm/relax/op/search.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Search operators.""" +from typing import Optional + +from . import _ffi_api +from ..expr import Expr + + +def where(condition: Expr, x1: Expr, x2: Expr) -> Expr: + """Selecting elements from either the input tensors depending on the value of the + condition. + + For a given position, return the corresponding value in `x1` if `condition` is True, + and return the corresponding value in `x2` otherwise. + + Parameters + ---------- + condition : relax.Expr + When True, yield `x1`; otherwise, yield `x2`. + Must be broadcasting compatible with `x1` and `x2`. + Must have boolean dtype. + + x1 : relax.Expr + The first input tensor. + Must be broadcasting compatible with `condition` and `x2`. + + x2 : relax.Expr + The second input tensor. + Must be broadcasting compatible with `condition` and `x1`. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.where(condition, x1, x2) # type: ignore + + +def argmax(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr: + """Computes the argmax of tensor elements over given axis. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[int] + Axis along which an argmax operation is performed. + The default, axis=None, will compute the argmax of all elements in the input tensor. + Negative indexing is supported. 
+ + keepdims : bool + If this is set to True, the axis being reduced is left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.argmax(x, axis, keepdims) # type: ignore + + +def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr: + """Computes the argmin of tensor elements over given axis. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[int] + Axis along which an argmin operation is performed. + The default, axis=None, will compute the argmin of all elements in the + input tensor. Negative indexing is supported. + + keepdims : bool + If this is set to True, the axis being reduced is left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.argmin(x, axis, keepdims) # type: ignore diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py new file mode 100644 index 000000000000..4d106ad6d23c --- /dev/null +++ b/python/tvm/relax/op/set.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument +"""Set operators.""" +from typing import Optional, Union + +import numpy as np # type: ignore +import tvm + +from . import _ffi_api +from ..expr import Expr, PrimValue + + +def unique( + x: Expr, + sorted: Union[bool, Expr] = True, + return_index: Union[bool, Expr] = False, + return_inverse: Union[bool, Expr] = False, + return_counts: Union[bool, Expr] = False, + axis: Optional[Union[int, Expr]] = None, +) -> Expr: + """Find the unique elements in a given tensor. + In addition, it optionally returns + - the indices of the input tensor that give the unique values; + - the indices of the unique tensor that reconstruct the input tensor; + - the number of times each unique value comes up in the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor. + + sorted : Union[bool, Expr] + Whether to sort the unique elements in ascending order before + returning as output. + + return_index : Union[bool, Expr] + Whether to return an additional tensor with indices for where elements in + the unique tensor come from the original input. + + return_inverse : Union[bool, Expr] + Whether to return an additional tensor with indices for where elements in + the original input ended up in the returned unique list. + + return_counts : Union[bool, Expr] + Whether to return an additional tensor with counts of each unique elements. + + axis : Optional + The dimension to apply unique. 
+ If not specified, the unique values of the flattened input are returned. + + Returns + ------- + ret : relax.Expr + The created relax call with + """ + + if isinstance(sorted, bool): + sorted = PrimValue(sorted) + if isinstance(return_index, bool): + return_index = PrimValue(return_index) + if isinstance(return_inverse, bool): + return_inverse = PrimValue(return_inverse) + if isinstance(return_counts, bool): + return_counts = PrimValue(return_counts) + if axis and isinstance(axis, int): + axis = PrimValue(axis) + return _ffi_api.unique( # type: ignore + x, sorted, return_index, return_inverse, return_counts, axis + ) + + +@tvm.register_func("relax.run.unique") +def numpy_unique( + x: tvm.nd.array, + sorted: int, + return_index: int, + return_inverse: int, + return_counts: int, +) -> tvm.nd.array: + """Returns the unique elements of the input tensor. + + Uses numpy.unique to compute unique elements. + """ + import builtins + + # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True + if bool(return_index) or bool(return_inverse) or bool(return_counts): + raise NotImplementedError("missing support return_inverse or return_counts set to true") + x_numpy = x.numpy() + # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. + output_sorted_numpy, indices = np.unique(x_numpy, return_index=True) + if sorted: + return tvm.nd.array(output_sorted_numpy) + output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)] + return tvm.nd.array(output_numpy) diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py new file mode 100644 index 000000000000..4669c783adda --- /dev/null +++ b/python/tvm/relax/op/statistical.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""Statistical operators.""" +from typing import List, Optional, Union + +from . import _ffi_api +from ..expr import Expr + + +def max(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the max of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a max operation is performed. + The default, axis=None, will compute the max of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. 
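+
+    Examples
+    --------
+    A minimal sketch, assuming ``x`` is any tensor-typed relax expression
+    (e.g. a function parameter)::
+
+        # reduce over the last axis and keep it as a size-1 dimension
+        y = tvm.relax.op.max(x, axis=-1, keepdims=True)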
+ """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.max(x, axis, keepdims) # type: ignore + + +def mean(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the mean of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a mean operation is performed. + The default, axis=None, will compute the mean of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.mean(x, axis, keepdims) # type: ignore + + +def min(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the min of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a min operation is performed. + The default, axis=None, will compute the min of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.min(x, axis, keepdims) # type: ignore + + +def prod(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the product of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a product is performed. + The default, axis=None, will compute the product of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.prod(x, axis, keepdims) # type: ignore + + +def std(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the standard deviation of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a standard deviation is performed. + The default, axis=None, will compute the std of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. 
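+
+    Examples
+    --------
+    A minimal sketch, assuming ``x`` is any tensor-typed relax expression::
+
+        # standard deviation over axes 0 and 1, dropping the reduced axes
+        y = tvm.relax.op.std(x, axis=[0, 1])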
+ """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.std(x, axis, keepdims) # type: ignore + + +def sum(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the sum of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a sum is performed. + The default, axis=None, will sum all of the elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.sum(x, axis, keepdims) # type: ignore + + +def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the variance of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a variance operation is performed. + The default, axis=None, will compute the variance of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.variance(x, axis, keepdims) # type: ignore diff --git a/python/tvm/relax/op/ternary.py b/python/tvm/relax/op/ternary.py new file mode 100644 index 000000000000..7c320cc1ca48 --- /dev/null +++ b/python/tvm/relax/op/ternary.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax ternary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr + + +def ewise_fma(x1: Expr, x2: Expr, x3: Expr) -> Expr: + """Elementwise fused multiply-add operator + Returns elementwise result of :math:`x1 * x2 + x3` + + Parameters + ---------- + x1 : relax.Expr + The left hand operand of the multiplication + + x2 : relax.Expr + The right hand operand of the multiplication + + x3 : relax.Expr + The operand of the addition + + Returns + ------- + result : relax.Expr + The computed result. 
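+
+    Examples
+    --------
+    A minimal sketch, assuming ``a``, ``b`` and ``c`` are tensor-typed relax
+    expressions of matching shape::
+
+        # elementwise a * b + c
+        d = tvm.relax.op.ewise_fma(a, b, c)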
+ """ + return _ffi_api.ewise_fma(x1, x2, x3) # type: ignore diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py new file mode 100644 index 000000000000..866d2a8273d6 --- /dev/null +++ b/python/tvm/relax/op/unary.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax unary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr +from ..utils import args_converter + +###################### Arithmetic operators ###################### + + +def abs(x: Expr) -> Expr: + """Compute element-wise absolute value of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.abs(x) # type: ignore + + +def acos(x: Expr) -> Expr: + """Compute element-wise arc cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acos(x) # type: ignore + + +def acosh(x: Expr) -> Expr: + """Compute element-wise arc cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acosh(x) # type: ignore + + +def asin(x: Expr) -> Expr: + """Compute element-wise arc sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asin(x) # type: ignore + + +def asinh(x: Expr) -> Expr: + """Compute element-wise arc sinh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asinh(x) # type: ignore + + +def atan(x: Expr) -> Expr: + """Compute element-wise arc tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atan(x) # type: ignore + + +def atanh(x: Expr) -> Expr: + """Compute element-wise arc tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. 
+ + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atanh(x) # type: ignore + + +def ceil(x: Expr) -> Expr: + """Take ceil of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.ceil(x) # type: ignore + + +def cos(x: Expr) -> Expr: + """Compute element-wise cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cos(x) # type: ignore + + +def cosh(x: Expr) -> Expr: + """Compute element-wise cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cosh(x) # type: ignore + + +def exp(x: Expr) -> Expr: + """Compute element-wise exp of data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.exp(x) # type: ignore + + +def floor(x: Expr) -> Expr: + """Take floor of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor(x) # type: ignore + + +def log(x: Expr) -> Expr: + """Compute element-wise natural logarithm of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.log(x) # type: ignore + + +def negative(x: Expr) -> Expr: + """Compute element-wise negative of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result + """ + return _ffi_api.negative(x) # type: ignore + + +def round(x: Expr) -> Expr: + """Rounds each element of the input data to nearest integer. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.round(x) # type: ignore + + +def sigmoid(x: Expr) -> Expr: + """Compute element-wise sigmoid of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sigmoid(x) # type: ignore + + +def sign(x: Expr) -> Expr: + """Returns an indication of the sign of a number for each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.sign(x) # type: ignore + + +def sin(x: Expr) -> Expr: + """Compute element-wise sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sin(x) # type: ignore + + +def sinh(x: Expr) -> Expr: + """Compute element-wise sinh of the input data. 
+ + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sinh(x) # type: ignore + + +def square(x: Expr) -> Expr: + """Squares each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.square(x) # type: ignore + + +def sqrt(x: Expr) -> Expr: + """Compute element-wise square root of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sqrt(x) # type: ignore + + +def tan(x: Expr) -> Expr: + """Compute element-wise tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tan(x) # type: ignore + + +def tanh(x: Expr) -> Expr: + """Compute element-wise tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tanh(x) # type: ignore + + +@args_converter.auto +def clip(x: Expr, min: Expr, max: Expr) -> Expr: + """Clips tensor values to a specified min and max. + + Parameters + ---------- + x : relax.Expr + The input data + + min : relax.Expr + The minimum value + + max : relax.Expr + The maximum value + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.clip(x, min, max) # type: ignore + + +###################### Check operators ###################### + + +def isfinite(x: Expr) -> Expr: + """Check if input value is finite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isfinite(x) # type: ignore + + +def isinf(x: Expr) -> Expr: + """Check if input value is infinite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isinf(x) # type: ignore + + +def isnan(x: Expr) -> Expr: + """Check if input value is Nan. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isnan(x) # type: ignore diff --git a/python/tvm/relax/op/vm/__init__.py b/python/tvm/relax/op/vm/__init__.py new file mode 100644 index 000000000000..ecb2857a893c --- /dev/null +++ b/python/tvm/relax/op/vm/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax vm primitives.""" + +from .vm import * diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py new file mode 100644 index 000000000000..786b73c76c64 --- /dev/null +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.vm""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.vm", __name__) diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py new file mode 100644 index 000000000000..a20407a4c94e --- /dev/null +++ b/python/tvm/relax/op/vm/vm.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""Relax vm primitives.""" + +from typing import Union +from . import _ffi_api +from ...expr import Expr, Call, PrimValue, DataTypeImm, Tuple +from ...utils import args_converter + + +@args_converter.auto +def alloc_storage( + size: Expr, + runtime_device_index: Union[int, Expr], + dtype: Union[str, Expr], +) -> Call: + """Construct a Call to allocate a storage with specific size, + runtime_device_index, and dtype. + + Parameters + ---------- + size : Expr + The size of the storage to be allocated. + + runtime_device_index : Union[int, Expr] + The device index indicating on which device the tensor is to + be allocated at runtime. Index -1 is reserved for the host device. + + dtype : Union[str, Expr] + The datatype of the storage to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated storage. 
+ """ + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + if isinstance(runtime_device_index, int): + runtime_device_index = PrimValue(runtime_device_index) + return _ffi_api.alloc_storage(size, runtime_device_index, dtype) # type: ignore + + +@args_converter.auto +def alloc_tensor( + storage: Expr, offset: Union[int, Expr], shape: Expr, dtype: Union[str, Expr] +) -> Call: + """Construct a Call to allocate a tensor on a certain storage starting from the given offset. + + Parameters + ---------- + storage : Expr + The storage to allocate the tensor to. + + offset : Union[int, Expr] + The storage offset to allocate the tensor. + + shape : Expr + The shape of the tensor to be allocated. + + dtype : Union[str, Expr] + The datatype of the tensor to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + if isinstance(offset, int): + offset = PrimValue(offset) + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + return _ffi_api.alloc_tensor(storage, offset, shape, dtype) # type: ignore + + +@args_converter.auto +def call_tir_dyn(func: Expr, args: Tuple) -> Call: + """Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc) + consisting of the input tensors and the shape of the result. + + Parameters + ---------- + func : Expr + An expression evaluating to a TIR PrimFunc. + + args : Tuple + The input args, includes a list of tensors, and a ShapeExpr. + + Returns + ------- + result : Call + A relax Call to call_tir_dyn. + """ + if isinstance(args, (list, tuple)): + args = Tuple(args) + + return _ffi_api.call_tir_dyn(func, args) # type: ignore diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py new file mode 100644 index 000000000000..a5da15b76d3b --- /dev/null +++ b/python/tvm/relax/pipeline.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Pre-defined pipelines. + +oRelax enables flexible pipeline optimizations before min build. +This namespace offers a pre-defined collection that can be used +as it is or serves as a basis to do further composition. +""" +# pylint: disable=unused-argument +import tvm +from tvm import meta_schedule as ms +from . import transform + + +@tvm.transform.module_pass(opt_level=0) +def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """Pipeline that applies pre-tuned logs. + + Parameters + ---------- + mod : tvm.ir.IRModule + Input IRModule. + + ctx : tvm.transform.PassContext + The pass context + + Returns + ------- + mod: tvm.ir.IRModule + The result transformed module. 
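+
+    Examples
+    --------
+    A minimal sketch, assuming ``mod`` is a Relax IRModule; the pipeline is an
+    ordinary module pass, so it can be applied directly or looked up by name::
+
+        mod = zero_pipeline(mod)
+        # or, equivalently
+        mod = get_pipeline("zero")(mod)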
+ """ + seq = tvm.transform.Sequential( + [ + transform.LegalizeOps(), + transform.AnnotateTIROpPattern(), + transform.FoldConstant(), + transform.FuseOps(), + transform.FuseTIR(), + ] + ) + mod = seq(mod) + if ms.Database.current(): + mod = transform.MetaScheduleApplyDatabase()(mod) + return mod + + +# global map of pre-built pipelines +PIPELINE_MAP = {"zero": zero_pipeline} + + +def get_pipeline(name: str = "zero") -> tvm.transform.Pass: + """Get pre-build pipeline by name + + Parameters + ---------- + name : Optional[str] + Name of the pipeline + + Returns + ------- + pipeline: tvm.transform.Pass + The transformation pipeline. + """ + + if name in PIPELINE_MAP: + return PIPELINE_MAP[name] + else: + raise ValueError( + f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" + ) diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py new file mode 100644 index 000000000000..2ff027b22924 --- /dev/null +++ b/python/tvm/relax/struct_info.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The struct info nodes of the Relax language.""" +from typing import List, Optional, Tuple, Union + +import tvm._ffi +import tvm + +from tvm.ir import Span, Node, EnvFunc, Array, Type +from tvm.tir import PrimExpr +from .expr import StructInfo, Var, Expr, ShapeExpr + +from . import _ffi_api, ty, expr + + +@tvm._ffi.register_object("relax.ObjectStructInfo") +class ObjectStructInfo(StructInfo): + """StructInfo of an Object.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore + + +@tvm._ffi.register_object("relax.PrimStructInfo") +class PrimStructInfo(StructInfo): + """StructInfo of a primitive POD value. + + Parameters + ---------- + dtype : str + The data type of the prim value. + """ + + dtype: str + + def __init__(self, dtype: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore + + +@tvm._ffi.register_object("relax.ShapeStructInfo") +class ShapeStructInfo(StructInfo): + """StructInfo of a shape value. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + + Note + ---- + Do not specify values and ndim at the same time. 
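+
+    Examples
+    --------
+    A minimal sketch::
+
+        known = ShapeStructInfo([1, 64, 56, 56])  # shape values are known
+        rank_only = ShapeStructInfo(ndim=4)       # only the rank is known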
+ """ + + values: Optional[List[PrimExpr]] + ndim: int + span: Span + + def __init__( + self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ShapeStructInfo, values, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TensorStructInfo") +class TensorStructInfo(StructInfo): + """StructInfo of a Tensor value. + + Parameters + ---------- + shape : Optional[Expr] + The shape expression. + + dtype : Optional[str] + The content data type. + + ndim : Optional[int] + The number of dimensions of the tensor. + + Note + ---- + Do not specify shape and ndim at the same time. + """ + + shape: Optional[Expr] + dtype: str + ndim: int + span: Span + + def __init__( + self, + shape: Union[Optional[Expr], List[PrimExpr]] = None, + dtype: str = "float32", + ndim: int = -1, + span: Span = None, + ) -> None: + if isinstance(shape, (list, tuple, Array)): + shape = ShapeExpr(shape) + + self.__init_handle_by_constructor__( + _ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TupleStructInfo") +class TupleStructInfo(StructInfo): + """StructInfo of a Tuple value. + + Parameters + ---------- + fields: List[StructInfo] + The struct info of the fields. + """ + + fields: List[StructInfo] + span: Span + + def __init__(self, fields: List[StructInfo], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore + + +@tvm._ffi.register_object("relax.FuncStructInfo") +class FuncStructInfo(StructInfo): + """StructInfo of a function value. + + Parameters + ---------- + params: List[StructInfo] + The struct info of the fields. + + ret: StructInfo + The struct info of return value + """ + + params: Optional[List[StructInfo]] + ret: StructInfo + derive_func: Optional[EnvFunc] + span: Span + + def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.FuncStructInfo, params, ret, span # type: ignore + ) + + @staticmethod + def opaque_func( + *, + ret: Optional[StructInfo] = None, + derive_func: Optional[EnvFunc] = None, + span: Span = None, + ) -> "FuncStructInfo": + """ + Create an opaque FuncStructInfo. + + The opaque function takes either a ret + that specificies the struct info of the return value + or a derive_func that provides a customized derivation rule. + + Parameters + ---------- + ret: Optional[StructInfo] + The struct info of the the function return value. + + derive_func: Optional[EnvFunc] + The environment function used for derivation + + span: Optional[Span] + Optional span information of the ast. + + Returns + ------- + info: FuncStructInfo + + Note + ---- + We cannot specify ret and derive_func simultaneously. + """ + return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py new file mode 100644 index 000000000000..a6e3a9425147 --- /dev/null +++ b/python/tvm/relax/testing/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""The Relax testing namespace containing nn and translator.""" + +from .nn import * +from .relay_translator import * +from .ast_printer import dump_ast diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py new file mode 100644 index 000000000000..6727b2429202 --- /dev/null +++ b/python/tvm/relax/testing/ast_printer.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, abstract-method, arguments-differ +""" +Utility script for printing Relax modules as AST diagrams, +only intended to show how the AST is put together. +It is not a pretty-printer and, in fact, is more of an ugly-printer, +but it can be useful for tutorials and debugging. +""" +from typing import Iterable +import tvm +from tvm import relax +from tvm.ir.expr import PrimExpr +from tvm.relax import ExprFunctor + + +def wrap_quotes(text: str) -> str: + """ + Wraps the text in quotes. + """ + return f'"{text}"' + + +class ASTPrinter(ExprFunctor): + """ + Class for recursing down ASTs and printing them in a very simple format, + mainly for instructive purposes and, perhaps, debugging. + """ + + def __init__( + self, + indent_str=" ", + include_struct_info_annotations=True, + include_type_annotations=False, + include_call_attrs=True, + ): + self.indent_str = indent_str + self.include_type_annotations = include_type_annotations + self.include_struct_info_annotations = include_struct_info_annotations + self.include_call_attrs = include_call_attrs + + def visit_expr(self, expr: relax.Expr) -> str: + # extend so we also dispatch to bindings and binding blocks, + # a little silly but IRFunctor hasn't been ported to Python + if isinstance(expr, relax.DataflowBlock): + return self.visit_dataflow_block_(expr) + if isinstance(expr, relax.BindingBlock): + return self.visit_binding_block_(expr) + if isinstance(expr, relax.Binding): + return self.visit_binding_(expr) + return super().visit_expr(expr) + + def indent(self, text: str) -> str: + """ + Indent all lines of the input. 
+ """ + if text == "": + return "" + lines = text.split("\n") + return self.indent_str + f"\n{self.indent_str}".join(lines) + + def build_ast_node(self, nodename: str, force_newline=False, **kwargs: str) -> str: + """ + Returns 'nodename(..., fields[i][0]=fields[i][1], ...)' + with appropriate indentation + """ + return self.build_list( + map(lambda field: f"{field[0]}={field[1]}", kwargs.items()), + open_tok=f"{nodename}(", + close_tok=")", + force_newline=force_newline, + ) + + def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwargs: str): + """ + Renders a Relax expression as a string using `build_ast_node`. + Handles whether to include the checked_type_ and struct_info fields. + """ + fields = kwargs.copy() + if node.struct_info_ and self.include_struct_info_annotations: + fields["struct_info"] = self.visit_struct_info_(node.struct_info) + if node._checked_type_ and self.include_type_annotations: + fields["checked_type_"] = self.visit_type_(node.checked_type) + return self.build_ast_node(nodename, force_newline=force_newline, **fields) + + def build_list( + self, members: Iterable[str], open_tok="[", close_tok="]", force_newline=False + ) -> str: + """ + Builds a list of the members given, appropriately indented, + with each field on a line. + (special case: if there is only one field, then we do not put it on a new line + unless that field contains a newline or `force_newline` is set to true). + `open_tok` and `close_tok` are used to open and close the list, respectively. + """ + mem_list = list(members) + if not mem_list: + return f"{open_tok}{close_tok}" + if len(mem_list) == 1 and not force_newline and "\n" not in mem_list[0]: + return f"{open_tok}{mem_list[0]}{close_tok}" + member_lines = ",\n".join(map(self.indent, mem_list)) + return f"{open_tok}\n{member_lines}\n{close_tok}" + + def visit_constant_(self, op: relax.Constant) -> str: + # simple rule of thumb: keep scalars inline, but anything larger goes on a new one + force_newline = len(op.data.shape) > 0 + return self.build_expr(op, "Constant", force_newline=force_newline, data=str(op.data)) + + def visit_tuple_(self, op: relax.Tuple) -> str: + return self.build_expr(op, "Tuple", fields=self.build_list(map(self.visit_expr, op.fields))) + + def visit_dataflow_var_(self, op: relax.DataflowVar) -> str: + return self.build_expr(op, "DataflowVar", name_hint=wrap_quotes(op.name_hint)) + + def visit_var_(self, op: relax.Var) -> str: + return self.build_expr(op, "Var", name_hint=wrap_quotes(op.name_hint)) + + def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: + return self.build_expr( + op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values)) + ) + + def visit_extern_func_(self, op: relax.ExternFunc) -> str: + # ExternFunc does not inherit from relax.Expr either, + # so it doesn't have checked_type_ or struct_info fields and we don't use build_expr + return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol)) + + def visit_global_var_(self, op: relax.GlobalVar) -> str: + return self.build_expr(op, "GlobalVar", name_hint=wrap_quotes(op.name_hint)) + + def visit_function_(self, op: relax.Function) -> str: + fields = { + "params": self.build_list(map(self.visit_expr, op.params)), + "body": self.visit_expr(op.body), + "ret_struct_info": self.visit_struct_info_(op.ret_struct_info), + } + if op.attrs: + fields["attrs"] = self.build_list( + map( + lambda kv: f"{wrap_quotes(str(kv[0]))}: {wrap_quotes(str(kv[1]))}", + op.attrs.items(), + ), + open_tok="{", + 
close_tok="}", + ) + return self.build_expr(op, "Function", **fields) + + def visit_call_(self, op: relax.Call) -> str: + fields = { + "op": self.visit_expr(op.op), + "args": self.build_list(map(self.visit_expr, op.args)), + } + if op.sinfo_args: + fields["sinfo_args"] = self.build_list(map(self.visit_struct_info_, op.sinfo_args)) + if op.attrs and self.include_call_attrs: + + def display_attrs(attr_key): + attr_val = op.attrs[attr_key] + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) + return f"{wrap_quotes(attr_key)}: {attr_str}" + + fields["attrs"] = self.build_list( + map(display_attrs, op.attrs.keys()), + open_tok="{", + close_tok="}", + ) + return self.build_expr(op, "Call", **fields) + + def visit_seq_expr_(self, op: relax.SeqExpr) -> str: + return self.build_expr( + op, + "SeqExpr", + blocks=self.build_list(map(self.visit_binding_block_, op.blocks)), + body=self.visit_expr(op.body), + ) + + def visit_if_(self, op: relax.If) -> str: + return self.build_expr( + op, + "If", + cond=self.visit_expr(op.cond), + true_branch=self.visit_expr(op.true_branch), + false_branch=self.visit_expr(op.false_branch), + ) + + def visit_prim_value_(self, op: relax.PrimValue) -> str: + return self.build_expr(op, "PrimValue", value=self.visit_prim_expr_(op.value)) + + def visit_string_imm_(self, op: relax.StringImm) -> str: + return self.build_expr(op, "StringImm", value=wrap_quotes(op.value)) + + def visit_data_type_imm_(self, op: relax.DataTypeImm) -> str: + return self.build_expr(op, "DataTypeImm", value=op.value) + + def visit_op_(self, op: tvm.ir.Op) -> str: + # TODO: List other attributes? + # op is not actually a Relax expr and does not have checked_type_ + # or struct_info fields, so we don't use build_expr here + return self.build_ast_node("Op", name=wrap_quotes(op.name)) + + def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: + # TODO: We may want to print PrimExpr ASTs, but this is a simplification for now + return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`") + + def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str: + return self.build_expr( + op, + "TupleGetItem", + tuple_value=self.visit_expr(op.tuple_value), + index=str(op.index), + ) + + def visit_type_(self, type_node: relax.Type) -> str: + """ + Recurse down types and print their ASTs too + """ + if isinstance(type_node, relax.ShapeType): + return self.build_ast_node("ShapeType", ndim=str(type_node.ndim)) + if isinstance(type_node, relax.ObjectType): + return self.build_ast_node("ObjectType") + if isinstance(type_node, relax.PackedFuncType): + return self.build_ast_node("PackedFuncType") + if isinstance(type_node, tvm.ir.PrimType): + return self.build_ast_node("PrimType", dtype=type_node.dtype) + if isinstance(type_node, relax.DynTensorType): + fields = {} + if type_node.ndim is not None: + fields["ndim"] = str(type_node.ndim) + if type_node.dtype != "": + fields["dtype"] = type_node.dtype + return self.build_ast_node("DynTensorType", **fields) + if isinstance(type_node, relax.TupleType): + return self.build_ast_node( + "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields)) + ) + if isinstance(type_node, relax.FuncType): + return self.build_ast_node( + "FuncType", + arg_types=self.build_list(map(self.visit_type_, type_node.arg_types)), + ret_type=self.visit_type_(type_node.ret_type), + # TODO: skipping type params and 
type constraints + ) + raise ValueError(f"Invalid Relax Type {type_node} ({type(type_node)})") + + def visit_struct_info_(self, struct_info_node: relax.StructInfo) -> str: + """ + Recurse down struct info and print their ASTs too + """ + if isinstance(struct_info_node, relax.ShapeStructInfo): + fields = {} + fields["ndim"] = str(struct_info_node.ndim) + if struct_info_node.values is not None: + fields["values"] = self.build_list( + map(self.visit_prim_expr_, struct_info_node.values) + ) + return self.build_ast_node("ShapeStructInfo", **fields) + elif isinstance(struct_info_node, relax.ObjectStructInfo): + return self.build_ast_node("ObjectStructInfo") + elif isinstance(struct_info_node, relax.PrimStructInfo): + return self.build_ast_node("PrimStructInfo", dtype=struct_info_node.dtype) + elif isinstance(struct_info_node, relax.TensorStructInfo): + fields = {} + fields["dtype"] = struct_info_node.dtype + if struct_info_node.shape: + fields["shape"] = self.visit_expr(struct_info_node.shape) + else: + fields["ndim"] = str(struct_info_node.ndim) + return self.build_ast_node("TensorStructInfo", **fields) + elif isinstance(struct_info_node, relax.TupleStructInfo): + return self.build_ast_node( + "TupleStructInfo", + fields=self.build_list(map(self.visit_struct_info_, struct_info_node.fields)), + ) + elif isinstance(struct_info_node, relax.FuncStructInfo): + fields = {} + if struct_info_node.params is not None: + fields["params"] = self.build_list( + map(self.visit_struct_info_, struct_info_node.params) + ) + fields["ret"] = self.visit_struct_info_(struct_info_node.ret) + return self.build_ast_node("FuncStructInfo", **fields) + else: + raise ValueError( + f"Invalid Relax StructInfo {struct_info_node} ({type(struct_info_node)})" + ) + + def visit_binding_block_(self, block: relax.BindingBlock) -> str: + """ + Recurse down binding blocks + """ + return self.build_ast_node( + "BindingBlock", + bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True), + ) + + def visit_dataflow_block_(self, block: relax.DataflowBlock) -> str: + """ + Recurse down a dataflow block + """ + return self.build_ast_node( + "DataflowBlock", + bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True), + ) + + def visit_binding_(self, binding: relax.Binding) -> str: + """ + Distinguish between binding types + """ + if isinstance(binding, relax.MatchCast): + return self.visit_match_cast_(binding) + if isinstance(binding, relax.VarBinding): + return self.visit_var_binding_(binding) + raise ValueError(f"Invalid binding type in {binding}: {type(binding)}") + + def visit_match_cast_(self, match_cast: relax.MatchCast) -> str: + """ + Handle match shape + """ + fields = { + "var": self.visit_expr(match_cast.var), + "value": self.visit_expr(match_cast.value), + "struct_info": self.visit_struct_info_(match_cast.struct_info), + } + return self.build_ast_node("MatchCast", **fields) + + def visit_var_binding_(self, var_binding: relax.VarBinding) -> str: + """ + Handle ordinary var bindings + """ + return self.build_ast_node( + "VarBinding", + var=self.visit_expr(var_binding.var), + value=self.visit_expr(var_binding.value), + ) + + +def dump_ast( + exp: relax.Expr, + indent_str=" ", + include_struct_info_annotations=True, + include_type_annotations=False, + include_call_attrs=True, +) -> str: + """ + Dump an AST in a text format. + Can vary the indentation string and choose whether to include + type and shape annotations or call attributes. 
+ """ + printer = ASTPrinter( + indent_str=indent_str, + include_struct_info_annotations=include_struct_info_annotations, + include_type_annotations=include_type_annotations, + include_call_attrs=include_call_attrs, + ) + return printer.visit_expr(exp) diff --git a/python/tvm/relax/testing/lib_comparator.py b/python/tvm/relax/testing/lib_comparator.py new file mode 100644 index 000000000000..a9cecc69dc6f --- /dev/null +++ b/python/tvm/relax/testing/lib_comparator.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Tools to compare libraries.""" +from typing import List, Tuple, Iterable, Union + +import tvm +import tvm.testing + + +class LibCompareVMInstrument: + """Instrument class to compare libs. + + This class build an instrument function that + pair tests an existing compiled relax vm implementation + and an extra module, which can sits in another backend + but offers a same subset of compiled TIR functions. + + The instrumentation enables us to automatically + check and compare each ops being called in the pipeline + by looking up the same name in the provided mod and run testing. + + Parameters + ---------- + mod: runtime.Module + The module of interest to be validated. + + device: runtime.Device + The device to run the target module on. + + verbose: bool + Whether print out messages. + + rtol: float + rtol used in validation + + atol: float + atol used in validation + """ + + def __init__(self, mod, device, verbose=True, rtol=1e-5, atol=1e-5): + self.mod = mod + self.device = device + self.verbose = verbose + self.counter = 0 + self.rtol = rtol + self.atol = atol + + def compare( + self, + name: str, + ref_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]], + new_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]], + ret_indices: Iterable[int], + ): + """Comparison function, can be overloaded. + + Parameters + ---------- + name: str + Name of the function. + + ref_args: + The reference arguments. + + new_args: + The args to be passed to the comparison function. + + ret_indices: + List of indices to validate return values. 
+ """ + my_func = self.mod.get_function(name, query_imports=True) + if self.verbose: + print(f"[{self.counter}] Validating {name} ...") + my_func(*new_args) + for rindex in ret_indices: + tvm.testing.assert_allclose( + new_args[rindex].numpy(), ref_args[rindex].numpy(), atol=self.atol, rtol=self.rtol + ) + if self.verbose: + print(f"[{self.counter}] Validating {name}, passed.") + self.counter += 1 + + def skip_instrument(self, func, name, before_run, ret_val, *args): + return False + + def __call__(self, func, name, before_run, ret_val, *args): + if before_run: + return + if name.startswith("vm.builtin."): + return + if any(not isinstance(x, tvm.nd.NDArray) for x in args): + return + try: + self.mod.get_function(name, query_imports=True) + except AttributeError: + if self.verbose: + print(f"Cannot find {name}, skip...") + return + + if self.skip_instrument(func, name, before_run, ret_val, *args): + return + + new_args = [] + # not always true, true for most ops. + ret_indices = (len(args) - 1,) + for i, arg in enumerate(args): + arr = tvm.nd.empty(arg.shape, device=self.device) + # copy from cpu since we look at different device + if i not in ret_indices: + arr.copyfrom(arg.copyto(tvm.cpu())) + new_args.append(arr) + + self.compare(name, args, new_args, ret_indices) diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py new file mode 100644 index 000000000000..830ddd779fe5 --- /dev/null +++ b/python/tvm/relax/testing/nn.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=redefined-builtin +"""PyTorch-like nn.Module API for constructing workloads.""" + + +from typing import List, Any, Callable, Union +import typing +import numpy as np # type: ignore + +import tvm +from tvm import relax, topi, tir + + +def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var: + return relax.BlockBuilder.current().emit_te(func, *args, **kwargs) + + +class Placeholder(relax.Var): + """A placeholder variable that can represent model input.""" + + def __init__( + self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="data" + ): + if not isinstance(shape, (list, tuple)): + raise TypeError("the shape of Placeholder is expected to be a list or a tuple") + super().__init__( + relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype) + ) + + +class Parameter(relax.Var): + """A special kind of relax Var that represents model parameter(weight).""" + + def __init__( + self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="param" + ): + if not isinstance(shape, (list, tuple)): + raise TypeError("the shape of Parameter is expected to be a list or a tuple") + super().__init__( + relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype) + ) + + +class Module: + """Base class for all model modules. + + A neural network or a layer can subclass this class. + + Example + ------- + .. code-block:: python + + # Define a linear layer + class Linear(Module) + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + # All submodules should implement forward. + # Defines the forward computation performed at every call. 
+ def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y + """ + + def parameters(self) -> List[Parameter]: + """Return the list of parameters in the module.""" + return _unpack_params(self.__dict__) + + def forward(self, input: relax.Expr): + """Define the computation performed at every call.""" + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _unpack_params(value: object) -> List[relax.Var]: + if isinstance(value, Parameter): + return [value] + if isinstance(value, Module): + return value.parameters() + if isinstance(value, dict): + params = [] + for v in value.values(): + params += _unpack_params(v) + return params + if isinstance(value, (list, tuple)): + params = [] + for v in value: + params += _unpack_params(v) + return params + if value is None or isinstance(value, (int, float, str)): + return [] + raise TypeError("not supported type when unpacking parameters: {}".format(type(value))) + + +def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: + """Utility function to initialize model's parameters.""" + shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} + params = [] + for k, v in shape_dict.items(): + if k.startswith("data"): + continue + if isinstance(v, relax.ShapeExpr): + shape = [] + for i in v: + if isinstance(i, tir.IntImm): + shape.append(int(i)) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + params.append(tvm.nd.array(np.zeros(shape).astype(np.float32))) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + return params + + +class Sequential(Module): + """A sequential container that concatenates modules in it. + + Example + ------- + .. code-block:: python + + model = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.Conv2d(20, 64, 5), + nn.ReLU() + ) + """ + + def __init__(self, *modules: Module): + self.modules = modules + + def forward(self, input: relax.Expr) -> relax.Var: + for module in self.modules: + input = module(input) + return input + + +class ReLU(Module): + """Applies the rectified linear unit activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.relu, input) + + +class LogSoftmax(Module): + """Applies log softmax activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.log_softmax, input) + + +class Linear(Module): + """Applies a linear transformation to the input data: :math:`y = xA + b`.""" + + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py new file mode 100644 index 000000000000..46fdb7021d20 --- /dev/null +++ b/python/tvm/relax/testing/relay_translator.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return, too-many-nested-blocks +"""Relay to Relax translator.""" + +from typing import Any, Dict, List, Optional + +import tvm +from tvm import relax, relay +from tvm.ir.module import IRModule +from tvm.relax.testing import nn +from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.meta_schedule.relay_integration import _autotvm_silencer + + +def from_relay( + func: relay.Function, + target: Target, + relay_params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None, + append_op_attrs: bool = False, +) -> IRModule: + """Convert a Relay function into a Relax program. + + Parameters + ---------- + func : relay.Function + Relay function to be converted. + + target: Target + The target to compile the model, used for selecting topi functions. + + relay_params: Optional[Dict[str, NDArray]] + Parameters to bind. + + opt_level: int + The optimization level. + + pass_config: Optional[Dict[str, Any]] + Pass configuration. + + disabled_pass: Optional[List[str]] + Passes to disable. + + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] + Dict that maps op names to user-defined PrimFuncs. + Takes relay operator names and forces them to user-defined PrimFuncs during translation. 
+ + append_op_attrs: bool + Append relay op attrs to generated prim_funcs + + Returns + ------- + mod : tvm.IRModule + The Relax IRModule for compilation + """ + # A map to store the mapping of Relay Expr to its corresponding Relax var + var_map = {} + # The output of the function + output_var = None + + if not isinstance(target, Target): + target = Target(target) + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = { + "relay.FuseOps.max_depth": 1, # Disable relay fusion + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": True, + } + + if relay_params: + func = relay.build_module.bind_params_by_name(func, relay_params) + + params = [] + tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr] = dict() + + def convert_shape(shape: List[tvm.tir.PrimExpr]) -> List[tvm.tir.PrimExpr]: + """Convert the relay shape to relax shape by changing Any dim to symbolic dim""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + ret.append(tvm.tir.IntImm("int64", int(dim))) + elif isinstance(dim, tvm.tir.Any): + ret.append(tvm.tir.Var("d", "int64")) + else: + ret.append(dim) + return ret + + def _copy_undefined_var_in_shape(sinfo: relax.TensorStructInfo): + def _visit_expr(e: tvm.tir.PrimExpr): + if isinstance(e, tvm.tir.Var) and e not in tir_var_map: + new_var = tvm.tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + assert isinstance( + sinfo.shape, relax.ShapeExpr + ), "arg with TensorStructInfo in Relay translator must have ShapeExpr shape" + for shape_value in sinfo.shape.values: + tvm.tir.stmt_functor.post_order_visit(shape_value, _visit_expr) + + def visit_func(node): + nonlocal output_var + if isinstance(node, relay.Var): + if isinstance(node.type_annotation, relay.TensorType): + var_map[node] = nn.Placeholder( + tuple(convert_shape(node.type_annotation.shape)), + node.type_annotation.dtype, + node.name_hint, + ) + params.append(var_map[node]) + else: + raise TypeError("The type of relay.Var to be translated must be of TensorType.") + elif isinstance(node, relay.Call): + args = node.args + new_args = [] + te_inputs = [] + for arg in args: + if arg in var_map: + arg_expr = var_map[arg] + if isinstance(arg_expr.struct_info, relax.TensorStructInfo): + _copy_undefined_var_in_shape(arg_expr.struct_info) + new_args.append(arg_expr) + te_inputs.append(tvm.relax.expr.te_tensor(arg_expr, tir_var_map)) + elif isinstance(arg_expr.struct_info, relax.TupleStructInfo): + n_tensor = len(arg_expr.struct_info.fields) + bound_tuple = bb.lookup_binding(arg_expr) + if isinstance(bound_tuple, relax.Tuple): + assert len(bound_tuple) == n_tensor + for i in range(n_tensor): + if isinstance(bound_tuple, relax.Tuple): + item = bb.emit(bound_tuple[i]) + else: + item = bb.emit(relax.TupleGetItem(arg_expr, i)) + + assert isinstance(item.struct_info, relax.TensorStructInfo), ( + "Relay translator doesn't support Call " + "argument being nested Tensor tuple." + ) + _copy_undefined_var_in_shape(item.struct_info) + new_args.append(item) + te_inputs.append(tvm.relax.expr.te_tensor(item, tir_var_map)) + else: + raise TypeError( + f"CallTIR argument type being {type(arg_expr.checked_type)} is not " + "supported." 
+ ) + + op_name = node.op.name + attrs = node.attrs + out_type = node.checked_type + + op_attrs_map = {} + if append_op_attrs: + func_attr_map = {"op_name": op_name} + if attrs: + for attr in attrs.keys(): + func_attr_map[attr] = attrs[attr] + + op_attrs_map["op_attrs"] = func_attr_map + + if translate_op_with_tir and op_name in translate_op_with_tir: + tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name) + call = relax.call_tir( + tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype) + ) + var = bb.emit(call) + else: + with target: + best_impl, outputs = select_implementation( + node.op, + attrs, + te_inputs, + out_type, + target, + use_autotvm=False, + ) + compute_func = best_impl.compute + name_hint = op_name.split(".")[-1] + var = bb.emit_te( + compute_func, + attrs, + new_args, + node.checked_type, + primfunc_name_hint=name_hint, + primfunc_attrs=op_attrs_map, + ) + + output_var = var + var_map[node] = var + elif isinstance(node, relay.Constant): + # fill the shape and checked_type fields of the Constant + new_constant = relax.Constant(node.data) + var_map[node] = new_constant + elif isinstance(node, relay.Tuple): + new_fields = [] + for field in node.fields: + if field in var_map: + new_fields.append(var_map[field]) + else: + raise RuntimeError("field is not in var_map.") + new_tuple = relax.Tuple(new_fields) + new_tuple_var = relax.BlockBuilder.current().emit(new_tuple) + var_map[node] = new_tuple_var + output_var = new_tuple_var + elif isinstance(node, relay.TupleGetItem): + if node.tuple_value in var_map: + new_tuple = var_map[node.tuple_value] + new_tuple_get_item_node = relax.TupleGetItem(new_tuple, node.index) + new_tuple_get_item_var = relax.BlockBuilder.current().emit(new_tuple_get_item_node) + var_map[node] = new_tuple_get_item_var + output_var = new_tuple_get_item_var + else: + raise RuntimeError("tuple is not in var_map") + elif isinstance(node, relay.Function): + cur_bb = relax.BlockBuilder.current() + gv = cur_bb.emit_output(output_var) + df_block = cur_bb._end_block() + cur_bb._blocks.append(df_block) + cur_bb.emit_func_output(gv, params) + elif isinstance(node, tvm.ir.Op): + pass + else: + raise TypeError("{} is not supported yet.".format(str(type(node)))) + + # List of subset of relay->relay optimizations + # See src/relay/backend/utils.cc::GetPassPrefix() for full list + seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True) + + # Since optimization passes and OpStrategy are highly context-dependent, + # we match the exact same context with `extract_task_from_relay()` env + with target, _autotvm_silencer(), tvm.transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + mod = tvm.IRModule.from_expr(func) + mod = seq(mod) + bb = relax.BlockBuilder() + with bb.function("main"): + bb._begin_dataflow_block() + relay.analysis.post_order_visit(mod["main"], visit_func) + + return bb.get() diff --git a/python/tvm/relax/testing/runtime_builtin.py b/python/tvm/relax/testing/runtime_builtin.py new file mode 100644 index 000000000000..1b04364e69fa --- /dev/null +++ b/python/tvm/relax/testing/runtime_builtin.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities for runtime builtin functions.""" +from enum import IntEnum + + +class MatchShapeCode(IntEnum): + """Code passed to match shape builtin""" + + ASSERT_EQUAL_TO_IMM = 0 + STORE_TO_HEAP = 1 + NO_OP = 2 + ASSERT_EQUAL_TO_LOAD = 3 + + +class MakeShapeCode(IntEnum): + """Code passed to match shape builtin""" + + USE_IMM = 0 + LOAD_SHAPE = 1 diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py new file mode 100644 index 000000000000..c8ca618d4c1a --- /dev/null +++ b/python/tvm/relax/testing/transform.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ +"""Relax transformation passes for testing""" + +from tvm import ir +from tvm import relax +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.target import Target +from tvm.ir import transform +from tvm.relax import PyExprMutator +from tvm.relax.expr import Call +from tvm.relay.backend.te_compiler import select_implementation + + +@ir.transform.module_pass(opt_level=0) +class LowerWithRelayOpStrategyPass(transform.Pass): + """Lower Relax Op into TIR by using Relay OpStrategy. + + Since operators like conv2d, add, matmul are relay-, relax- independent, + this pass assumes we can always find relay op equivalent for such relax ops, + and use Relay Op Strategy (legacy) to perform lowering and find the TOPI implementation. + + Parameters + ---------- + target : Target + target info + + Returns + ------- + pass : transform.Pass + lowering pass + """ + + def __init__(self, target: Target): + self.target = target + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + """Implement lowering mechanism. 
+ + Parameters + ---------- + mod : IRModule + Input IRModule with Relax ops + + ctx: PassContext + Pass context + + Returns + ------- + out_mod : IRModule + Output IRModule with lowered TIR functions + """ + target = self.target + + @relax.expr_functor.mutator + class Lowerer(PyExprMutator): + """Mutator that performs lowering.""" + + def visit_call_(self, call_node: Call): + # Ignore function calls + # We only target calls for operators + if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)): + return call_node + + # Current relax op name simply adds "relax." prefix to relay op name. + # Thus, remove "relax." prefix to deduce relay op name. + relay_op_name = call_node.op.name[6:] + # Check if equivalent relay op exists. If not, return the original call. + if relay_op_name in ir.Op.list_op_names(): + relay_op = ir.Op.get(relay_op_name) + + # Todo(relax-team): to be revisited - support dyn shape or deprecate. + tir_var_map = dict() + te_inputs = [relax.expr.te_tensor(arg, tir_var_map) for arg in call_node.args] + best_impl_tuple = select_implementation( + relay_op, + call_node.attrs, + te_inputs, + call_node.checked_type, + target, + use_autotvm=False, + ) + compute_func = best_impl_tuple[0].compute + # Extract the name of the operator without the prefix + # e.g., for relay op "nn.conv2d", name_hint would be conv2d + name_hint = relay_op_name.split(".")[-1] + + return self.builder_.call_te( + compute_func, + call_node.attrs, + call_node.args, + call_node.attrs, + primfunc_name_hint=name_hint, + ) + else: + return call_node + + # TOOD(@team): transform() wapper is necessary to include TIR functions. + # IMO, this is bit unintuitive. Can we improve this? + def transform(self): + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + return new_mod + + return Lowerer().transform() diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py new file mode 100644 index 000000000000..79da54be1010 --- /dev/null +++ b/python/tvm/relax/testing/vm.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
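
A short, hedged sketch of how the `LowerWithRelayOpStrategyPass` defined in `testing/transform.py` above might be applied. The `lower_for_target` helper and the `"llvm"` target string are illustrative only, and the input module is assumed to contain Relax op calls that have Relay equivalents (e.g. `relax.add`, `relax.nn.conv2d`).

```python
import tvm
from tvm.target import Target
from tvm.relax.testing.transform import LowerWithRelayOpStrategyPass


def lower_for_target(mod: tvm.IRModule, target: str = "llvm") -> tvm.IRModule:
    # The pass rewrites each eligible relax op call into a call_te onto the
    # TOPI compute chosen by the (legacy) Relay OpStrategy for `target`.
    return LowerWithRelayOpStrategyPass(Target(target))(mod)
```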
+# pylint: disable=invalid-name +"""Testing utilities for relax VM""" +from typing import Any, List +import numpy as np # type: ignore + +import tvm +from tvm import relax +from tvm.runtime.object import Object + + +@tvm.register_func("test.vm.move") +def move(src): + return src + + +@tvm.register_func("test.vm.add") +def add(a, b): + ret = a.numpy() + b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.mul") +def mul(a, b): + ret = a.numpy() * b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.equal_zero") +def equal_zero(a): + ret = np.all((a.numpy() == 0)) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.subtract_one") +def subtract_one(a): + ret = np.subtract(a.numpy(), 1) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.identity") +def identity_packed(a, b): + b[:] = tvm.nd.array(a.numpy()) + + +@tvm.register_func("test.vm.tile") +def tile_packed(a, b): + b[:] = tvm.nd.array(np.tile(a.numpy(), (1, 2))) + + +@tvm.register_func("test.vm.add_scalar") +def add_scalar(a, b): + return a + b + + +@tvm.register_func("test.vm.get_device_id") +def get_device_id(device): + return device.device_id + + +def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any]) -> Object: + # uses save_function to create a closure with the given inputs + # and ensure the result is the same + # (assumes the functions return tensors and that they're idempotent) + saved_name = f"{func_name}_saved" + vm.save_function(func_name, saved_name, *inputs) + res1 = vm[func_name](*inputs) + res2 = vm[saved_name]() + tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7) + return res1 diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py new file mode 100644 index 000000000000..78f450b25ce2 --- /dev/null +++ b/python/tvm/relax/transform/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax transformations. """ + +from .transform import * + +# Import to register the legalization functions. +from . import legalize_ops diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py new file mode 100644 index 000000000000..667aa62c2c95 --- /dev/null +++ b/python/tvm/relax/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.transform""" +import tvm._ffi + +tvm._ffi._init_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py new file mode 100644 index 000000000000..3e57b815dbd8 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Legalize high-level operator calls in Relax functions to call_tir.""" +from . import binary +from . import creation +from . import datatype +from . import image +from . import index +from . import linear_algebra +from . import manipulate +from . import nn +from . import search +from . import statistical +from . import unary diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py new file mode 100644 index 000000000000..897b67651883 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for binary operators.""" +from tvm import topi +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import TEFunc, LegalizeFunc, _try_convert_to_scalar_const, register_legalize + + +def _binary(te_func: TEFunc) -> LegalizeFunc: + """A common wrapper util for the legalization of binary operators. + + It detects if one of the binary op arguments is a constant scalar. It so, + it extracts the scalar value to simplify the generated PrimFunc. 
+ """ + + def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: + # To simplify the created PrimFunc, we first check if arg1 is a constant scalar. + # If it is not, we then check if arg0 is a constant scalar. + arg0 = call.args[0] + arg1 = _try_convert_to_scalar_const(call.args[1]) + if isinstance(arg1, Expr): # type: ignore + arg0 = _try_convert_to_scalar_const(arg0) + return bb.call_te(te_func, arg0, arg1) + + return binary_call_te + + +register_legalize("relax.add", _binary(topi.add)) +register_legalize("relax.divide", _binary(topi.divide)) +register_legalize("relax.floor_divide", _binary(topi.floor_divide)) +register_legalize("relax.multiply", _binary(topi.multiply)) +register_legalize("relax.power", _binary(topi.power)) +register_legalize("relax.subtract", _binary(topi.subtract)) +register_legalize("relax.equal", _binary(topi.equal)) + +register_legalize("relax.greater", _binary(topi.greater)) +register_legalize("relax.greater_equal", _binary(topi.greater_equal)) +register_legalize("relax.less", _binary(topi.less)) +register_legalize("relax.less_equal", _binary(topi.less_equal)) +register_legalize("relax.not_equal", _binary(topi.not_equal)) + +register_legalize("relax.maximum", _binary(topi.maximum)) +register_legalize("relax.minimum", _binary(topi.minimum)) diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py new file mode 100644 index 000000000000..4ee9c6758f2c --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common functionality for legalization.""" +from typing import Callable, Optional, Union + +import tvm +from tvm import te +from tvm.tir import FloatImm, IntImm +from ...block_builder import BlockBuilder +from ...expr import Call, Expr, Constant + + +##################### Types ##################### + + +# The function type of a TE function, which accepts TE Tensors and +# other attributes, and returns the output TE Tensor. +TEFunc = Callable[..., te.Tensor] + +# The function type of a legalization function, which takes a +# BlockBuilder and the Call to be legalized, and outputs the legalization +# result Expr. +LegalizeFunc = Callable[[BlockBuilder, Call], Expr] + + +##################### Utilities ##################### + + +def _try_convert_to_scalar_const( + expr: Expr, python_native: bool = False +) -> Union[Expr, FloatImm, IntImm, bool, float, int]: + """Check if the input Expr is a scalar constant. + If it is, return its plain value with the same data type or in native python type. + If it is not, return the input expr. 
+ + Note that if the python_native flag is True, the returned value will be in native python type, + this might cause loss of data type for example, a float16 constant will be converted to float32 + and a int64 constant will be converted to int32. + + Parameters + ---------- + expr : Expr + The expr to be checked and converted. + + Returns + ------- + ret : Union[Expr, FloatImm, IntImm, bool, float, int] + Return a FloatImm or IntImm if the given expr is a scalar integer or float constant, and the + python native flag is False. Or return the plain value of the constant in native python type + if the python native flag is True. + Or return the input itself if it is not a scalar constant. + """ + if isinstance(expr, Constant) and expr.struct_info.ndim == 0: + # get the value of the scalar constant + value = expr.data.numpy()[()].item() + dtype = expr.struct_info.dtype + if python_native: + return value + # preserve the data type of the constant + if dtype.startswith("float"): + return tvm.tir.FloatImm(dtype, value) + elif dtype.startswith("int") or dtype.startswith("uint") or dtype.startswith("bool"): + return tvm.tir.IntImm(dtype, value) + return expr + + +def _call_topi_without_attr(te_func: TEFunc, primfunc_name: Optional[str] = None) -> LegalizeFunc: + """A common wrapper util for the ops who has no attributes and whose + legalization is simply passing its arguments to some TE function. + + Parameters + ---------- + te_func : TEFunc + The input TE function which is to be converted to PrimFunc. + + primfunc_name : Optional[str] + The name of the generated PrimFunc. + If it is not specified, the name of `te_func` will be used by default. + + Returns + ------- + func : LegalizeFunc + The legalization wrapper function, which wraps the input TE function. + """ + if primfunc_name is None: + primfunc_name = te_func.__name__ + return lambda bb, call: bb.call_te(te_func, *call.args, primfunc_name_hint=primfunc_name) + + +##################### Decorators ##################### + +_LEGALIZE_ATTR_NAME = "FLegalize" + + +def register_legalize(op_name: str, legal_func: LegalizeFunc = None): + """Register legal transformation function for a Relax op. + + Parameters + ---------- + op_name : str + The name of the operator + + legal_func: function (bb: BlockBuilder, call: Call) -> new_expr: Expr + The function for transforming an expr to another expr. + """ + return tvm.ir.register_op_attr(op_name, _LEGALIZE_ATTR_NAME, legal_func) diff --git a/python/tvm/relax/transform/legalize_ops/creation.py b/python/tvm/relax/transform/legalize_ops/creation.py new file mode 100644 index 000000000000..76548fcfb439 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/creation.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
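
To make the contract of `_try_convert_to_scalar_const` above concrete, a small illustrative sketch follows; the trailing comments describe the kind of value returned, not exact reprs.

```python
import numpy as np
from tvm import relax
from tvm.relax.transform.legalize_ops.common import _try_convert_to_scalar_const

scalar = relax.const(1.5, "float16")                    # 0-dim Constant
tensor = relax.const(np.ones((2, 2), dtype="float32"))  # 2-dim Constant

_try_convert_to_scalar_const(scalar)                      # FloatImm with dtype float16
_try_convert_to_scalar_const(scalar, python_native=True)  # 1.5 as a plain Python float
_try_convert_to_scalar_const(tensor)                      # returned unchanged (not a scalar)
```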
+# pylint: disable=invalid-name +"""Default legalization function for creation operators.""" +from typing import Optional + +from tvm import topi, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import LegalizeFunc, register_legalize, _try_convert_to_scalar_const + + +def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> LegalizeFunc: + def full_call_te(bb: BlockBuilder, call: Call) -> Expr: + _fill_value = ( + _try_convert_to_scalar_const(call.args[1], python_native=True) + if fill_value is None + else fill_value + ) + + return bb.call_te( + topi.full, + call.args[0].struct_info.shape if is_like else call.args[0], + call.struct_info.dtype, + _fill_value, + primfunc_name_hint=primfunc_name, + ) + + return full_call_te + + +def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc: + def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.trilu, + call.args[0], + tir.const(call.attrs.k, "int32"), + upper=is_upper, + primfunc_name_hint=primfunc_name, + ) + + return tril_triu_call_te + + +register_legalize("relax.full", _full(is_like=False, fill_value=None, primfunc_name="full")) +register_legalize("relax.full_like", _full(is_like=True, fill_value=None, primfunc_name="full")) +register_legalize("relax.ones", _full(is_like=False, fill_value=1.0, primfunc_name="ones")) +register_legalize("relax.ones_like", _full(is_like=True, fill_value=1.0, primfunc_name="ones")) +register_legalize("relax.zeros", _full(is_like=False, fill_value=0.0, primfunc_name="zeros")) +register_legalize("relax.zeros_like", _full(is_like=True, fill_value=0.0, primfunc_name="zeros")) +register_legalize("relax.tril", _tril_triu(is_upper=False, primfunc_name="tril")) +register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu")) diff --git a/python/tvm/relax/transform/legalize_ops/datatype.py b/python/tvm/relax/transform/legalize_ops/datatype.py new file mode 100644 index 000000000000..8e1d88577577 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/datatype.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
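
As a sanity check on the registrations in `creation.py` above, here is a hedged end-to-end sketch. It assumes `relax.op.ones_like` and a `LegalizeOps` pass wired up under `relax.transform` elsewhere in this patch; the shapes are arbitrary.

```python
import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
with bb.function("main", [x]):
    # High-level creation op; after legalization this becomes a call_tir into a
    # generated "ones" PrimFunc (see _full(is_like=True, fill_value=1.0) above).
    gv = bb.emit(relax.op.ones_like(x))
    bb.emit_func_output(gv)

legalized = relax.transform.LegalizeOps()(bb.get())
```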
+# pylint: disable=invalid-name +"""Default legalization function for datatype operators.""" +from tvm import topi, relax +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import _try_convert_to_scalar_const, register_legalize + + +@register_legalize("relax.astype") +def _astype(bb: BlockBuilder, call: Call) -> Expr: + arg = _try_convert_to_scalar_const(call.args[0], python_native=True) + if isinstance(arg, Expr): # type: ignore + return bb.call_te(topi.cast, arg, call.attrs.dtype) + else: + return relax.const(arg, call.attrs.dtype) diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py new file mode 100644 index 000000000000..1b2a342b0b53 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/image.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for image operators.""" +from tvm import topi +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.image.resize2d") +def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.image.resize2d, + call.args[0], + roi=call.attrs.roi, + size=call.args[1], + layout=call.attrs.layout, + method=call.attrs.method, + coordinate_transformation_mode=call.attrs.coordinate_transformation_mode, + rounding_method=call.attrs.rounding_method, + bicubic_alpha=call.attrs.cubic_alpha, + bicubic_exclude=call.attrs.cubic_exclude, + extrapolation_value=call.attrs.extrapolation_value, + ) diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py new file mode 100644 index 000000000000..eccccc7c6d3f --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
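
A brief note on `_astype` above: a 0-dim constant argument is folded directly into a new `relax.const` of the target dtype, while any other argument goes through `topi.cast`. The sketch below only constructs the two call forms that exercise the two branches; the actual rewriting happens inside the legalization pass.

```python
from tvm import relax

scalar = relax.const(1.0, "float32")                              # folded branch
tensor = relax.Var("x", relax.TensorStructInfo((4,), "float32"))  # topi.cast branch

# During legalization, the first call produces no PrimFunc at all, while the
# second lowers to call_tir over a "cast" PrimFunc generated from topi.cast.
cast_scalar = relax.op.astype(scalar, "float16")
cast_tensor = relax.op.astype(tensor, "float16")
```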
+# pylint: disable=invalid-name +"""Default legalization function for index operators.""" +import logging + +from tvm import topi, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.take") +def _take(bb: BlockBuilder, call: Call) -> Expr: + # Currently Relax `take` operator doesn't provide the mode choices and + # requires input indices to be in range. + # We use fast mode, which leads to runtime error whenever some index is + # out of bound. + return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast") + + +@register_legalize("relax.strided_slice") +def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: + if not all( + isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm) + for i in call.attrs.axes + ): + logging.info( + "Cases where an axis with symbolic length is sliced are not able " + "to be legalized through TOPI" + ) + return call + + strides = ( + [tir.IntImm("int64", 1)] * len(call.attrs.axes) + if call.attrs.strides is None + else call.attrs.strides + ) + return bb.call_te( + topi.strided_slice, + call.args[0], + call.attrs.begin, + call.attrs.end, + strides, + call.attrs.axes, + slice_mode="end", + ) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py new file mode 100644 index 000000000000..abe21d9fffee --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
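
To pin down the semantics assumed by `_take` above (indices must already be in range, since TOPI's "fast" mode does no bounds handling), here is a small NumPy mirror of what the legalized kernel computes.

```python
import numpy as np

data = np.arange(12, dtype="float32").reshape(3, 4)
indices = np.array([2, 0], dtype="int64")

# Equivalent of call_tir(topi.take(..., axis=0, mode="fast")) for in-range indices.
expected = np.take(data, indices, axis=0)  # rows 2 and 0, shape (2, 4)
```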
+# pylint: disable=invalid-name +"""Default legalization function for linear algebra operators.""" +from tvm import te, relax, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.matmul") +def _matmul(bb: BlockBuilder, call: Call) -> Expr: + def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + b_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + + dtype = call.attrs.out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + else: + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="matmul", + ) + + return bb.call_te(te_matmul, call.args[0], call.args[1], primfunc_name_hint="matmul") diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py new file mode 100644 index 000000000000..144ef04748c5 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
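
The broadcasting rules implemented by `te_matmul` above follow the usual batched-matmul convention: 1-D operands are temporarily promoted with a unit dimension, batch dimensions broadcast, and the temporary unit dimensions are dropped from the output shape. A NumPy mirror, for reference:

```python
import numpy as np

a = np.random.rand(5, 1, 3, 4).astype("float32")
b = np.random.rand(2, 4, 6).astype("float32")
v = np.random.rand(4).astype("float32")

np.matmul(a, b).shape  # (5, 2, 3, 6): batch dims (5, 1) and (2,) broadcast
np.matmul(a, v).shape  # (5, 1, 3): v gets an appended unit dim, then it is dropped
np.matmul(v, b).shape  # (2, 6): v gets a prepended unit dim, then it is dropped
```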
+# pylint: disable=invalid-name +"""Default legalization function for manipulate operators.""" +import logging +from typing import Optional + +import tvm +from tvm import topi, tir, relax, te +from tvm.tir.expr import IntImm +from ...block_builder import BlockBuilder +from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr +from .common import TEFunc, LegalizeFunc, register_legalize + + +def _reshape( + te_func: TEFunc, primfunc_name: str, is_collapse_sum_like: bool = False +) -> LegalizeFunc: + def reshape_call_te(bb: BlockBuilder, call: Call): + tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] + # If target shape is Var, pass its bound expr only when it is ShapeExpr + if isinstance(tgt_shape, Var): + tgt_shape = bb.lookup_binding(tgt_shape) + assert isinstance(tgt_shape, ShapeExpr) + return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name) + + return reshape_call_te + + +register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) +register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) +register_legalize( + "relax.collapse_sum_like", + _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), +) +register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) + + +@register_legalize("relax.concat") +def _concat(bb: BlockBuilder, call: Call) -> Expr: + t = call.args[0] + n_field = len(t.struct_info.fields) + while isinstance(t, Var): + binding = bb.lookup_binding(t) + if not isinstance(binding, (Tuple, Var)): + break + t = binding + + assert isinstance(t, (Tuple, Var)) + fields = ( + t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] + ) + return bb.call_te( + topi.concatenate, fields, None if call.attrs.axis is None else call.attrs.axis.value + ) + + +@register_legalize("relax.expand_dims") +def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: + def te_expand_dims(data, axis): + data_relax = relax.Var("data", relax.TensorStructInfo(data.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.expand_dims(data_relax, axis), bb).shape + output_ndim = len(output_shape) + + data_dims = [] + for i in range(output_ndim): + if i not in axis and (i - output_ndim) not in axis: + data_dims.append(i) + return te.compute( + output_shape, + lambda *idx: data(*[idx[dim] for dim in data_dims]), + name="expand_dims", + ) + + return bb.call_te( + te_expand_dims, call.args[0], call.attrs.axis, primfunc_name_hint="expand_dims" + ) + + +@register_legalize("relax.flatten") +def _flatten(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) + + +@register_legalize("relax.permute_dims") +def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.transpose, call.args[0], call.attrs.axes) + + +@register_legalize("relax.split") +def _split(bb: BlockBuilder, call: Call) -> Expr: + if isinstance(call.attrs.indices_or_sections, tir.IntImm): + indices_or_sections = call.attrs.indices_or_sections.value + modulo = tvm.arith.Analyzer().simplify( + call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections + ) + if modulo != 0: + logging.info( + "Split cannot be legalized by TOPI when the axis being split has " + "length that not divisible by the input number of section." 
+ ) + return call + else: + indices_or_sections = call.attrs.indices_or_sections + return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) + + +@register_legalize("relax.squeeze") +def _squeeze(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis) + + +@register_legalize("relax.repeat") +def _repeat(bb: BlockBuilder, call: Call) -> Expr: + def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]): + if axis is None: + # flatten data + out_shape = data.shape[0] + for i in data.shape[1:]: + out_shape *= i + data = topi.reshape(data, (out_shape,)) + axis = 0 + # topi only receives int repeats and axis + return topi.repeat(data, int(repeats), int(axis)) + + return bb.call_te( + te_repeat, call.args[0], call.attrs.repeats, call.attrs.axis, primfunc_name_hint="repeat" + ) + + +@register_legalize("relax.tile") +def _tile(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.tile, call.args[0], call.attrs.repeats) + + +@register_legalize("relax.cumsum") +def _cumsum(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py new file mode 100644 index 000000000000..1ce45206354d --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -0,0 +1,370 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Default legalization function for neural network operators.""" +import logging + +from tvm import topi, tir, te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize, _call_topi_without_attr + + +@register_legalize("relax.nn.conv1d") +def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv1d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if len(call.attrs.data_layout) != 3 or len(call.attrs.kernel_layout) != 3: + logging.info( + "Conv1D where data layout or kernel layout have channel chunk " + "cannot be legalized by TOPI at this moment." 
+ ) + return call + if call.attrs.groups != 1: + data_layout = tir.layout(call.attrs.data_layout) + kernel_layout = tir.layout(call.attrs.kernel_layout) + ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] + oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + logging.info( + "Conv1D where number of groups is more than one and input or output " + "channel size is symbolic cannot be legalized by TOPI at this moment." + ) + return call + + return bb.call_te( + topi.nn.conv1d, + data=call.args[0], + kernel=call.args[1], + strides=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + data_layout=call.attrs.data_layout, + kernel_layout=call.attrs.kernel_layout, + out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, + primfunc_name_hint="conv1d", + ) + + +@register_legalize("relax.nn.conv2d") +def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if len(call.attrs.data_layout) != 4 or len(call.attrs.kernel_layout) != 4: + logging.info( + "Conv2D where data layout or kernel layout have channel chunk " + "cannot be legalized by TOPI at this moment." + ) + return call + if call.attrs.groups != 1: + data_layout = tir.layout(call.attrs.data_layout) + kernel_layout = tir.layout(call.attrs.kernel_layout) + ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] + oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + logging.info( + "Conv2D where number of groups is more than one and input or output " + "channel size is symbolic cannot be legalized by TOPI at this moment." 
+ ) + return call + + return bb.call_te( + topi.nn.conv, + inp=call.args[0], + filt=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + groups=call.attrs.groups, + data_layout=call.attrs.data_layout, + kernel_layout=call.attrs.kernel_layout, + out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, + primfunc_name_hint="conv2d", + ) + + +@register_legalize("relax.nn.conv2d_transpose") +def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv2d_transpose does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW": + logging.info( + "TOPI conv2d_transpose does not support input layout other than NCHW, " + "and kernel layout other than IOHW, so cannot be legalized by TOPI" + ) + return call + dilation = call.attrs.dilation + if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1: + logging.info( + "TOPI conv2d_transpose does not support dilations other than 1, " + "and thus cannot be legalized by TOPI" + ) + return call + + return bb.call_te( + topi.nn.group_conv2d_transpose_nchw, + call.args[0], + call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + out_dtype=call.struct_info.dtype, + output_padding=call.attrs.output_padding, + groups=call.attrs.groups, + primfunc_name_hint="conv2d_transpose", + ) + + +@register_legalize("relax.nn.max_pool2d") +def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI max_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + return bb.call_te( + topi.nn.pool2d, + call.args[0], + kernel=call.attrs.pool_size, + stride=call.attrs.strides, + dilation=call.attrs.dilation, + padding=call.attrs.padding, + pool_type="max", + ceil_mode=call.attrs.ceil_mode, + layout=call.attrs.layout, + primfunc_name_hint="max_pool2d", + ) + + +@register_legalize("relax.nn.avg_pool2d") +def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI avg_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + return bb.call_te( + topi.nn.pool2d, + call.args[0], + kernel=call.attrs.pool_size, + stride=call.attrs.strides, + dilation=call.attrs.dilation, + padding=call.attrs.padding, + pool_type="avg", + ceil_mode=call.attrs.ceil_mode, + layout=call.attrs.layout, + primfunc_name_hint="avg_pool2d", + ) + + +@register_legalize("relax.nn.adaptive_avg_pool2d") +def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI adaptive_avg_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + def te_adaptive_avg_pool2d(data, output_size, layout_str): + if output_size is None: + layout = tir.layout(layout_str) + idx_H = layout.index_of("H") + idx_W = layout.index_of("W") + assert idx_H != -1 and idx_W != -1 + output_size = (data.shape[idx_H], data.shape[idx_W]) + + return topi.nn.adaptive_pool(data, output_size, "avg", layout_str) + + return bb.call_te( + te_adaptive_avg_pool2d, + call.args[0], + call.attrs.output_size, + call.attrs.layout, + 
primfunc_name_hint="adaptive_avg_pool2d", + ) + + +register_legalize("relax.nn.relu", _call_topi_without_attr(topi.nn.relu)) + + +@register_legalize("relax.nn.gelu") +def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: + def te_gelu(x: te.Tensor): + dtype = x.dtype + erf_inp = x * tir.const(0.5**0.5, dtype) + + if dtype == "float16": + erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), "float16") + else: + erf = topi.erf(erf_inp) + + return x * (tir.const(0.5, dtype) + erf * tir.const(0.5, dtype)) + + return bb.call_te(te_gelu, call.args[0], primfunc_name_hint="gelu") + + +@register_legalize("relax.nn.silu") +def _nn_silu(bb: BlockBuilder, call: Call) -> Expr: + def te_silu(x: te.Tensor): + return topi.multiply(x, topi.sigmoid(x)) + + return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu") + + +@register_legalize("relax.nn.softmax") +def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis) + + +@register_legalize("relax.nn.log_softmax") +def _nn_log_softmax(bb: BlockBuilder, call: Call): + return bb.call_te(topi.nn.log_softmax, call.args[0], call.attrs.axis) + + +@register_legalize("relax.nn.cross_entropy_with_logits") +def _nn_cross_entropy_with_logits(bb: BlockBuilder, call: Call): + def te_cross_entropy_with_logits(x, y): + if len(x.shape) > 1: + return -topi.sum(x * y) / x.shape[0] + return -topi.sum(x * y) + + return bb.call_te( + te_cross_entropy_with_logits, + call.args[0], + call.args[1], + primfunc_name_hint="cross_entropy_with_logits", + ) + + +@register_legalize("relax.nn.batch_norm") +def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.batch_norm, + data=call.args[0], + gamma=call.args[1], + beta=call.args[2], + moving_mean=call.args[3], + moving_var=call.args[4], + axis=call.attrs.axis, + epsilon=call.attrs.epsilon, + center=call.attrs.center, + scale=call.attrs.scale, + ) + + +@register_legalize("relax.nn.layer_norm") +def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.layer_norm, + call.args[0], + call.args[1], + call.args[2], + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + + +@register_legalize("relax.nn.group_norm") +def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.group_norm, + call.args[0], + call.args[1], + call.args[2], + call.attrs.num_groups, + call.attrs.channel_axis, + call.attrs.axes, + call.attrs.epsilon, + ) + + +@register_legalize("relax.nn.dropout") +def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: + logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") + return call + + +def _te_attention( + q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm +) -> te.Tensor: + batch_size, seq_len, num_head, head_dim = q.shape + _, seq_len_kv, _, head_dim_v = v.shape + q = topi.transpose(q, [0, 2, 1, 3]) + k = topi.transpose(k, [0, 2, 1, 3]) + v = topi.transpose(v, [0, 2, 1, 3]) + q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim]) + k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim]) + v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v]) + p = topi.nn.batch_matmul(q, k) + if scale is not None: + p = topi.multiply(p, scale) + else: + p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim))) + if bias is not None: + p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv]) + if len(bias.shape) == 2: + bias = topi.reshape(bias, 
[batch_size, 1, 1, seq_len_kv]) + elif len(bias.shape) == 3: + bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv]) + p = topi.add(p, bias) + p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv]) + s = topi.nn.softmax(p) + o = topi.nn.batch_matmul(s, v, transpose_b=False) + o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v]) + return topi.transpose(o, [0, 2, 1, 3]) + + +@register_legalize("relax.nn.attention") +def _nn_attention(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_attention, + call.args[0], + call.args[1], + call.args[2], + None, + call.attrs.scale, + primfunc_name_hint="attention", + ) + + +@register_legalize("relax.nn.attention_bias") +def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_attention, + call.args[0], + call.args[1], + call.args[2], + call.args[3], + call.attrs.scale, + primfunc_name_hint="attention_bias", + ) diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py new file mode 100644 index 000000000000..6d36795e76f3 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for search operators.""" +from tvm import topi +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import TEFunc, LegalizeFunc +from .common import _call_topi_without_attr, register_legalize + +register_legalize("relax.where", _call_topi_without_attr(topi.where)) + + +def _argmax_argmin(te_func: TEFunc) -> LegalizeFunc: + def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + te_func, + call.args[0], + None if call.attrs.axis is None else call.attrs.axis.value, + call.attrs.keepdims, + ) + + return argmax_argmin_call_te + + +register_legalize("relax.argmax", _argmax_argmin(topi.argmax)) +register_legalize("relax.argmin", _argmax_argmin(topi.argmin)) diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py new file mode 100644 index 000000000000..3307d49f219f --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for statistical operators.""" +from typing import List +from tvm import topi, tir, te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import TEFunc, LegalizeFunc, register_legalize + + +def _statistical(te_func: TEFunc) -> LegalizeFunc: + def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) + + return statistical_call_te + + +def _compute_shape_prod(x: te.Tensor, axis: List[tir.IntImm]) -> tir.PrimExpr: + shape_prod = tir.const(1, "int32") + axes = [_axis.value for _axis in axis] if axis is not None else range(0, len(x.shape)) + for dim in axes: + shape_prod = shape_prod * x.shape[dim] + return shape_prod + + +def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + shape_prod = _compute_shape_prod(x, axis) + res_sum = topi.sum(x, axis, keepdims) + return topi.divide(res_sum, shape_prod) + + +def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + dev = x - _te_mean(x, axis, keepdims) + return _te_mean(dev * dev, axis, keepdims) + + +@register_legalize("relax.mean") +def _mean(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_mean, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="mean" + ) + + +@register_legalize("relax.std") +def _std(bb: BlockBuilder, call: Call) -> Expr: + def te_std(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + return topi.sqrt(_te_variance(x, axis, keepdims)) + + return bb.call_te( + te_std, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="std" + ) + + +@register_legalize("relax.variance") +def _variance(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_variance, + call.args[0], + call.attrs.axis, + call.attrs.keepdims, + primfunc_name_hint="variance", + ) + + +register_legalize("relax.max", _statistical(topi.max)) +register_legalize("relax.min", _statistical(topi.min)) +register_legalize("relax.prod", _statistical(topi.prod)) +register_legalize("relax.sum", _statistical(topi.sum)) diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py new file mode 100644 index 000000000000..ca84cbf0add7 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/unary.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default legalization function for unary operators.""" +from tvm import topi +from .common import _call_topi_without_attr, register_legalize + +# To avoid conflict of IRModule function name and libc function name, we add +# "tir_" as the prefix of the generated PrimFunc name. +register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs")) +register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil")) +register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos")) +register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log")) +register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp")) +register_legalize("relax.floor", _call_topi_without_attr(topi.floor, "tir_floor")) +register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative")) +register_legalize("relax.round", _call_topi_without_attr(topi.round, "tir_round")) +register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, "tir_sigmoid")) +register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign")) +register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin")) +register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt")) +register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh")) +register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip")) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py new file mode 100644 index 000000000000..049ac2947f27 --- /dev/null +++ b/python/tvm/relax/transform/transform.py @@ -0,0 +1,1026 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Relax transformation passes.""" +import functools +import inspect +import types +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import numpy as np # type: ignore + +import tvm.ir +from tvm.relax import Expr, Var +from tvm.relax.dpl import DFPattern +from tvm.runtime import NDArray, Object +from tvm.tir import IndexMap, PrimFunc + +from . import _ffi_api +from .legalize_ops.common import LegalizeFunc + + +@tvm._ffi.register_object("relax.FunctionPass") +class FunctionPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.Function in a module. A function + pass class should be created through `function_pass`. 
+ """ + + +@tvm._ffi.register_object("relax.DataflowBlockPass") +class DataflowBlockPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.DataflowBlock in a module.""" + + +def ToNonDataflow() -> tvm.ir.transform.Pass: + """Transform all dataflow structure to non-dataflow version. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ToNonDataflow() # type: ignore + + +def LambdaLift(): + """A pass that lifts local functions into global. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.LambdaLift() + + +def CallTIRRewrite() -> tvm.ir.transform.Pass: + """Perform explicit tensor allocation for call_tir and call_dps_packed. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CallTIRRewrite() # type: ignore + + +def Normalize() -> tvm.ir.transform.Pass: + """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting + and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.Normalize() # type: ignore + + +def CanonicalizeBindings() -> tvm.ir.transform.Pass: + """ + Canonicalizes variable definitions + (e.g., if there is y = x and z = y, it replaces uses of y and z with x). + + Best combined with constant folding and the elimination of unused definitions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeBindings() # type: ignore + + +def EliminateCommonSubexpr() -> DataflowBlockPass: + """Eliminate common subexpressions within dataflow blocks. + + Note: For functions local to dataflow blocks, this pass performs + CSE *within* those functions + + Returns + ------- + ret : tvm.transform.Pass + The registered pass that eliminates common subexpressions. + """ + return _ffi_api.EliminateCommonSubexpr() # type: ignore + + +def RewriteDataflowReshape() -> tvm.ir.transform.Pass: + """Convert all reshape-like call_tir to VM reshape operator call. + The VM reshape operator calls will be further lowered to a CreateView + operation at runtime, instead of doing real data copy. + Here "reshape-like" includes reshape, expand_dims, flatten, etc. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.RewriteDataflowReshape() # type: ignore + + +def StaticPlanBlockMemory() -> tvm.ir.transform.Pass: + """The static memory planning pass on BindingBlock level. + The pass will reuse allocated memory to its best effort, in order to + reduce the total amount of allocated memory size. + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.StaticPlanBlockMemory() # type: ignore + + +def VMBuiltinLower() -> tvm.ir.transform.Pass: + """Lowering generic intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMBuiltinLower() # type: ignore + + +def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: + """Lower the symbolic shape and argument and match-cast structinfo matching. + + Parameters + ---------- + emit_err_ctx: Optional[bool] + Whether emit err context string, can be turned off for testing purposes. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore + + +def AttachGlobalSymbol() -> tvm.ir.transform.Pass: + """Attach global_symbol to Relax functions and TIR Primfuncs for codegen. 
+ + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AttachGlobalSymbol() # type: ignore + + +def BindParams( + func_name: str, + params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], +) -> tvm.ir.transform.Pass: + """Bind params of function of the module to constant tensors. + + Parameters + ---------- + + func_name: str + The function name to be bound + + params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]] + The map from param name to constant tensors. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + tvm_params = {} + for k, v in params.items(): + if isinstance(v, np.ndarray): + v = tvm.nd.array(v) + assert isinstance( + v, tvm.runtime.NDArray + ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" + tvm_params[k] = v + + return _ffi_api.BindParams(func_name, tvm_params) # type: ignore + + +def RunCodegen( + target_options: Optional[dict] = None, + entry_functions: Optional[List[str]] = None, +) -> tvm.ir.transform.Pass: + """Produce the runtime::Module with an annotated codegen and global symbol. + + Parameters + ---------- + target_options: Optional[dict] + Pairs of a target name and compilation options + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to remove unused functions. + """ + if entry_functions is None: + entry_functions = ["main"] + # enable cutlass byoc registries + # pylint: disable=unused-import,import-outside-toplevel + from tvm.contrib import cutlass as _cutlass + + return _ffi_api.RunCodegen(target_options, entry_functions) # type: ignore + + +def FoldConstant() -> tvm.ir.transform.Pass: + """Fold constant expressions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.FoldConstant() # type: ignore + + +def AnnotateTIROpPattern() -> tvm.ir.transform.Pass: + """Annotate Op Pattern Kind for TIR functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AnnotateTIROpPattern() # type: ignore + + +def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: + """This pass groups bindings in a dataflow block of Relax functions and generate a new grouped + Relax function for each group, according to the fusion algorithm described in the pass + implementation. By grouping bindings into new Relax functions, we substitute the bindings in + the function being manipulated into function calls to the new grouped function. + + A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for operator fusion. + """ + return _ffi_api.FuseOps(fuse_opt_level) # type: ignore + + +def FuseTIR() -> tvm.ir.transform.Pass: + """Fuse primitive relax function into a larger TIR function if possible + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for tir fusion. + """ + return _ffi_api.FuseTIR() # type: ignore + + +@tvm._ffi.register_object("relax.transform.PatternCheckContext") +class PatternCheckContext(Object): + """ + The input of check function `FusionPattern.check`. + + Parameters + ---------- + matched_expr: Expr + The expression that's matched with the FusionPattern.pattern. 
+ + annotated_expr: Mapping[str, Expr] + A map which contains all expressions matched by the sub patterns in + FusionPattern.annotation_patterns. + + matched_bindings: Mapping[Var, Expr] + Map from variable to its value. It contains variables from bindings that is + being fused by FuseOpsByPattern. + + var_usages: Mapping[Var, Sequence[Var]] + A map mapping variable definitions to a set of uses. It has all variables + used in the function. + + value_to_bound_var: Mapping[Expr, Var] + Map from value to its bound variable. It doesn't have variables after the + matched expression. + """ + + matched_expr: Expr + annotated_expr: Mapping[str, Expr] + matched_bindings: Mapping[Var, Expr] + var_usages: Mapping[Var, Sequence[Var]] + value_to_bound_var: Mapping[Expr, Var] + + +@tvm._ffi.register_object("relax.transform.FusionPattern") +class FusionPattern(Object): + """ + The pattern used by `FuseOpsByPattern`. It's mainly DFPattern but with other + information to help during the fusion pass. + + Parameters + ---------- + name: str + The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'. + + pattern: DFPattern + The dataflow pattern that will be used to match expressions that can be handled + by external backends. + + annotation_patterns: Mapping[str, DFPattern] + The map which is used to extract important expressions from the pattern match + result. All DFPattern in this map should be part of the `pattern`. + + check: Callable[[PatternCheckContext], bool] + The function to check whether the match result is accepted. + """ + + name: str + pattern: DFPattern + annotation_patterns: Mapping[str, DFPattern] + check: Callable[[PatternCheckContext], bool] + + def __init__( + self, + name: str, + pattern: DFPattern, + annotation_patterns: Optional[Mapping[str, DFPattern]] = None, + check: Optional[Callable[[Mapping[str, Expr]], bool]] = None, + ): + if annotation_patterns is None: + annotation_patterns = {} + self.__init_handle_by_constructor__( + _ffi_api.FusionPattern, name, pattern, annotation_patterns, check # type: ignore + ) + + +def FuseOpsByPattern( + patterns: List[Union[FusionPattern, Tuple]], + bind_constants: bool = True, + annotate_codegen: bool = False, +) -> tvm.ir.transform.Pass: + """Apply pattern matching to each function in the given module, and group matched expressions + into a new function. + + The end result is similar to FuseOps, but fusion is driven completely by the provided patterns. + + Parameters + ---------- + patterns : List[Union[FusionPattern, Tuple]] + A list of patterns to be matched. The order of the patterns determines the order of priority + in which they are matched. Higher-priority patterns should come earlier in the list. + + In addition to FusionPattern, a tuple can be passed as item of this list. The pattern + will be constructed through FusionPattern(*item) + + bind_constants : bool + Whether or not to keep bound constants in the grouped function. + + annotate_codegen : bool + If True, wrap each created composite function with another function, whose body consists + only of a call to the composite function, and annotate the outer function with "Codegen" + and "global_symbol" attributes. The "Codegen" attribute is set as the prefix of the + corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu". + + This must be True if the created composite functions are intended to be offloaded to + an external backend without using the MergeCompositeFunctions pass. 
+ + Returns + ------- + ret : tvm.transform.Pass + The registered pass for pattern-based fusion. + + """ + converted_patterns = [] + for pattern in patterns: + if isinstance(pattern, tuple): + converted_patterns.append(FusionPattern(*pattern)) + elif isinstance(pattern, FusionPattern): + converted_patterns.append(pattern) + else: + raise ValueError(f"Invalid pattern: {pattern}") + + return _ffi_api.FuseOpsByPattern( + converted_patterns, + bind_constants, + annotate_codegen, + ) # type: ignore + + +def MergeCompositeFunctions() -> tvm.ir.transform.Pass: + """Group one or multiple composite functions created by FuseOpsByPattern into a new function. + The new function will be annotated with "Codegen" and "global_symbol" attributes, and it + is intented to be offloaded to an external backend. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for merging composite functions. + """ + return _ffi_api.MergeCompositeFunctions() # type: ignore + + +def LiftTransformParams() -> tvm.ir.transform.Pass: + """Lift transformation of the parameters of a function. + + When some inputs of the function is marked as 'parameters' (the model weights), this pass + identifies the transformation of the parameters and lifts them to a separate function called + `transform_params`. `transform_params` takes a tuple of the original parameters as input and + returns a tuple of the transformed parameters. The original function will be rewritten to accept + a tuple of transformed parameters as input. + + Users are expected to invoke the `transform_params` function in runtime and pass the transformed + parameters to the original function as input. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for lifting transformation of parameters. + """ + return _ffi_api.LiftTransformParams() # type: ignore + + +def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): + """Legalize high-level operator calls in Relax functions to call_tir + with corresponding low-level TIR PrimFuncs. + + For each high-level operator, we register the way of legalizing it as a + function, which takes a context BlockBuilder and the Call being legalized + as input, and returns the legalized call. Here the input BlockBuilder is + mainly used for adding the PrimFunc created by call_te into the context + IRModule. + + The legalization function for each operator is registered as an attribute (with + attribute key `FLegalize`) of the operator. + + This pass provides customizability for users to use their own legalization + function for operators. The pass takes an optional customized map, + with the key to be the operator name (`str`) and value to be the function + (`LegalizeFunc`). The default legalization function will be overridden by the customized + one. + + Parameters + ---------- + customize_legalize_map : Optional[Dict[str, LegalizeFunc]] + The customized operator legalization function map. The customized function will override + the default one. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass + + Examples + -------- + The following code shows how to use this pass: + + .. 
code-block:: python + + # Define the pass input IRModule + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z: R.Tensor((2, 3), "float32") = R.add(x, y) + r: R.Tensor((2, 3), "float32") = R.multiply(y, z) + return r + + # Define the customized legalization function for "relax.add" + def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + from tvm import topi + return bb.call_te(topi.add, call.args[1], call.args[0]) + + # Apply the pass with the customized function to the module. + mod = LegalizeOps({"relax.add": customize_legalize_add})(Module) + + Print out the result by `mod.show()`, we can see the IRModule after + legalization becomes + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z = R.call_tir(add, (y, x), (2, 3), dtype="float32") + r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32") + return r + + @T.prim_func + def add( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + T_add: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + @T.prim_func + def multiply( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + T_multiply: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] + """ + + return _ffi_api.LegalizeOps(customize_legalize_map) # type: ignore + + +def MetaScheduleApplyDatabase( + work_dir: Optional[str] = None, +) -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + work_dir : Optional[str] + work directory to deduce default database if database is not provided + (it will be ignored when an user passes database) + Returns + ------- + ret : tvm.transform.Pass + The registered pass + """ + return _ffi_api.MetaScheduleApplyDatabase(work_dir) # type: ignore + + +def MetaScheduleTuneTIR( + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune TIR with MetaSchedule. + Parameters + ---------- + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) # type: ignore + + +def MetaScheduleTuneIRMod( + params: Dict[str, NDArray], + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune Relax IRModule with MetaSchedule. + Parameters + ---------- + params: Dict[str, NDArray] + model params + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore + + +def DecomposeCompositeOps() -> tvm.ir.transform.Pass: + """Decompose composite operators that are composed by other operators during inference. 
+ For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked + into a number of simplified operators. Attention, tensor_to_shape, etc. can be also + decomposed into a number of simplified operators as well. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass + """ + + return _ffi_api.DecomposeCompositeOps() # type: ignore + + +def AlterOpImpl( + op_impl_map: Dict[str, PrimFunc], + op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]], +): + """Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement + PrimFunc that could possibly have different layouts on i/o buffers. The layout + transformations on i/o buffers is present in the op_buffer_transforms map. Inserts the layout + transformations in the call sites of PrimFuncs being replaced to transform i/o + tensors into expected layout by new PrimFunc. + + Parameters + ---------- + op_impl_map: Dict[str, PrimFunc] + op_kind to PrimFunc map + op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]] + op_kind to layout transformation map for each of the buffers + Returns + ------- + ret: tvm.ir.transform.Pass + """ + for operator_name, transform_list in op_buffer_transforms.items(): + l = [] + for transform in transform_list: + if isinstance(transform, Callable): + transform = IndexMap.from_func(transform) + l.append(transform) + op_buffer_transforms[operator_name] = l + + return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type: ignore + + +def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass: + """Automatic layout conversion pass. + Parameters + ---------- + desired_layouts : Dict[str, List[str]] + The desired layout of conv2d ops is a map from the name of the op to the desired layout + of the desired feature map, weight and output. For example, if we want to convert the + layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be + {"conv2d": ["NHWC", "OHWI"]}. + Returns + ------- + ret : tvm.transform.Pass + The registered pass for layout conversion. + """ + return _ffi_api.ConvertLayout(desired_layouts) # type: ignore + + +def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass: + """Remove dead code in the IRModule. + Currently it removes: + 1. Unused local VarBindings in a DataflowBlock. + 2. Unused DataflowBlocks in a function. + 3. Unused Relax functions in the module. + We detect the call chain from the entry function, and remove all unused functions. + + Parameters + ---------- + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Notes + ----- + For function-wise DCE, use py:func:`tvm.relax.analysis.remove_all_unused`. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass. + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.DeadCodeElimination(entry_functions) # type: ignore + + +def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass: + """Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 + only, and will automatically cast fp32 to fp16 for certain ops. + Parameters + ---------- + out_dtype : str + The output data type of gemm/conv, which is the data type of the accumulator. + Returns + ------- + ret : tvm.transform.Pass + The registered pass for mixed precision. 
+ """ + return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore + + +def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass: + """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is + matched with some pattern, and the second part is the rest of the original + PrimFunc. It will call fcodegen to generate the code for the matched pattern + to replace it with a ExternFunc call. + Parameters + ---------- + patterns : List[PrimFunc] + The list of patterns to match. + fcodegen: Callable[[List[MatchResult]], List[Object]] + The function to generate the code for the matched patterns. + Returns + ------- + ret : tvm.transform.Pass + The registered pass for splitting call_tir. + """ + return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass.""" + + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeFunctionPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def function_pass( + pass_func=None, + opt_level=None, + name=None, + required=None, + traceable=False, +) -> Union[Callable, FunctionPass]: + """Decorate a function pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. + + opt_level : int + The optimization level of this function pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the function pass is dependent on. + + traceable: Boolean + Boolean variable whether the function pass is traceable + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. 
code-block:: python + + @relax.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + @R.function + def f1(x: Tensor[(m, n), "float32"]): + return x + + @tvm.script.ir_module + class InputMod: + @R.function + def f2(x: Tensor[(m, n), "float32"]): + gv0 = relax.add(x, x) + return gv0 + # fpass is now a special pass that replaces every + # function to f1 + fpass = TestReplaceFunc(f1) + # now every function in InputMod is replaced by f1 + updated_mod = fpass(InputMod) + + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here. + return func + + function_pass = transform + assert isinstance(function_pass, relax.transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now transform should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the function pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Function pass") + return _ffi_api.MakeFunctionPass(pass_arg, info) # type: ignore + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass + + +def _wrap_class_dataflowblock_pass(pass_cls, pass_info): + """Wrap a python class as dataflowblock pass""" + + class PyDataflowBlockPass(DataflowBlockPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_dataflowblock(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__) + PyDataflowBlockPass.__name__ = pass_cls.__name__ + PyDataflowBlockPass.__doc__ = pass_cls.__doc__ + PyDataflowBlockPass.__module__ = pass_cls.__module__ + return PyDataflowBlockPass + + +def dataflowblock_pass( + pass_func=None, opt_level=None, name=None, required=None, traceable=False +) -> Union[Callable, DataflowBlockPass]: + """Decorate a dataflowblock pass. + + This function returns a callback when pass_func + is provided. 
Otherwise, it returns the created dataflowblock pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]] + The transformation function or class. + + opt_level : int + The optimization level of this dataflowblock pass. + + name : Optional[str] + The name of the dataflowblock pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the dataflowblock pass is dependent on. + + traceable: Boolean + Boolean variable whether the dataflowblock pass is traceable + + Returns + ------- + create_dataflowblock_pass : Union[Callable, DataflowBlockPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new DataflowBlockPass will be returned when we decorate a pass function. + A new DataflowBlockPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a dataflowblock pass class. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=1) + class TestReplaceBinding: + # Simple test function to replace the first VarBinding to another. + + def __init__(self): + # create a new VarBinding + m, n = tir.Var("m", "int64"), tir.Var("n", "int64") + lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32")) + val = relax.const(np.random.rand(24, 56)) + self.new_binding = relax.VarBinding(lv0, val) + + def transform_dataflowblock(self, block, mod, ctx): + # just for demo purposes + # Replace the first binding in the DataflowBlock + new_bindings = [self.new_binding, block.bindings[1]] + new_block = relax.expr.DataflowBlock(new_bindings, block.span) + return new_block + + @tvm.script.ir_module + class InputMod: + @R.function + def f1(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + lv0 = relax.multiply(x, x) + gv0 = relax.add(x, x) + relax.output(gv0) + return gv0 + # block_pass is now a special pass that replaces every + # first binding to the constant value binding + block_pass = TestReplaceBinding() + # now every first binding in DataflowBlock of InputMod + # is replaced by new_binding + updated_mod = block_pass(InputMod) + + + The following code creates a dataflowblock pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=2) + def transform(block, mod, ctx): + # my transformations here. + return block + + block_pass = transform + assert isinstance(block_pass, relax.transform.DataflowBlockPass) + assert block_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = block_pass(m) + # Now transform should have been applied to every DataflowBlock in + # the provided module m. And the updated module will be returned. 
+ """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the dataflowblock pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_dataflowblock_pass(pass_arg): + """Internal function that creates a dataflowblock pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) + if inspect.isclass(pass_arg): + return _wrap_class_dataflowblock_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for DataflowBlock pass") + return _ffi_api.MakeDataflowBlockPass(pass_arg, info) # type: ignore + + if pass_func: + return create_dataflowblock_pass(pass_func) + return create_dataflowblock_pass diff --git a/python/tvm/relax/transform/tuning_api/__init__.py b/python/tvm/relax/transform/tuning_api/__init__.py new file mode 100644 index 000000000000..6c39d5c5359e --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax Tunign Pass API""" + +from .primitives import * +from .default_functions import * +from .database import * diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py new file mode 100644 index 000000000000..f31522d02595 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +"""FFI APIs for relax.tuning_api""" +import tvm._ffi + +tvm._ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py new file mode 100644 index 000000000000..9477e142bad4 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/database.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import List, Optional +import logging + +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.meta_schedule.utils import _json_de_tvm +from tvm.meta_schedule.database import Workload +from tvm.tir.schedule.trace import JSON_TYPE +from tvm.target import Target +from tvm._ffi import register_object +from .primitives import Trace +from . import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.relax.transform.tuning_api.Trace + The trace of the tuning record. + run_secs : Optional[List[float]] + The run-time of the tuning record. + """ + + trace: Trace + run_secs: Optional[List[float]] + + def __init__( # type: ignore # pylint: disable=too-many-arguments + self, + trace: Trace, + run_secs: Optional[List[float]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + ) + + def as_json(self, include_irmod: bool = False) -> JSON_TYPE: + """Export the tuning record to a JSON string. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self, include_irmod)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : JSON_TYPE + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.Database") +class Database(Object): + """The abstract database interface.""" + + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the given workload is committed. 
+ """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def has_measurement_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a measurement record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the measurement record. + """ + return _ffi_api.DatabaseHasMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def has_tuning_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a tuning record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the tuning record. + """ + return _ffi_api.DatabaseHasTuningRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_measurement_record( + self, workload: Workload, target: Target, run_secs: List[float] + ) -> None: + """Commit a measurement record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + run_secs : Optional[List[float]] + The measurement record to add. + """ + _ffi_api.DatabaseCommitMeasurementRecord(self, workload, target, run_secs) # type: ignore # pylint: disable=no-member + + def commit_tuning_record( + self, workload: Workload, target: Target, record: TuningRecord + ) -> None: + """Commit a tuning record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + record : TuningRecord + The tuning record to add. + """ + _ffi_api.DatabaseCommitTuningRecord(self, workload, target, record) # type: ignore # pylint: disable=no-member + + def get_measurement_record(self, workload: Workload, target: Target) -> Optional[List[float]]: + """Get the measurement record of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + measurement_record : Optional[List[float]] + Measurement record if exists. + """ + return _ffi_api.DatabaseGetMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, target: Target, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + top_k : int + The number of top records to get. 
+ + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, target, top_k) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.JSONDatabase") +class JSONDatabase(Database): + """The class of JSON database. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + Manages pairs of + path_measurement_record : str + The path to the path_measurement_record table. + Manages pairs of + """ + + path_workload: str + path_tuning_record: str + path_measurement_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + path_measurement_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + path_measurement_record : str + The path to the path_measurement_record table. + allow_missing : bool + Whether to create new file when the given path is not found. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + path_measurement_record, + allow_missing, + ) diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py new file mode 100644 index 000000000000..7cdb211bd32f --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import Dict, List, Optional +import sys +import itertools +import logging +import numpy as np # type: ignore + +import tvm +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext, Pass +from tvm import meta_schedule +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) +from tvm._ffi.registry import register_func +from .primitives import Knob, Trace + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + +# Default transform func that returns original IRModule. +@tvm.register_func("relax.tuning_api.Choice.default_transform_func") +def default_transform_func(mod): + return mod + + +# Default constraint func that always returns true. 
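Before the default constraint function below, a minimal sketch of driving the JSON database defined above (the file paths, the IRModule `mod`, and the timings are hypothetical):

.. code-block:: python

    from tvm.target import Target
    from tvm.relax.transform.tuning_api import JSONDatabase

    db = JSONDatabase(
        path_workload="workload.json",
        path_tuning_record="tuning_record.json",
        path_measurement_record="measurement_record.json",
    )
    target = Target("llvm")

    workload = db.commit_workload(mod)  # registers `mod` if it is not present yet
    db.commit_measurement_record(workload, target, [1.5e-3, 1.6e-3])

    if db.has_measurement_record(workload, target):
        run_secs = db.get_measurement_record(workload, target)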
+@tvm.register_func("relax.tuning_api.Choice.default_constr_func") +def default_constr_func(mod: IRModule) -> bool: # pylint: disable=unused-argument + return True + + +@register_func("relax.tuning_api.default_generate_candidate") +def default_generate_candidate( + knobs: List[Knob], trace: Trace, eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to generate the search space for a given trace by using registered choices. + This function simply expands candidate space as long as the knob's constraint satisfies. + To reduce the search space, a developer may expand each choice with smart search method. + (e.g., genetic search, multi-armed bandit) + Note, each pass generates candidates without worrying about the interaction with other passes. + i.e., it only uses its incoming trace/IRModule and Choices for candidate generation. + This will help alleviating the complexity of joint-optimization significantly. + - consideration of interaction between optimizations has known to be extremely difficult. + + Parameters + ---------- + knobs : List[Knob] + List of Knobs to consider to generate candidate for input trace. + trace: Trace + Input trace. + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. + + Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + + candidates = [trace] + # Iterate over every decision + for knob in knobs: + num = len(candidates) + for _ in range(num): + cur_trace = candidates.pop(0) + for decision in knob.choices.keys(): + choice = knob.choices[decision] + # Generate new candidate when this condition satisfies. + if choice.check_constr(cur_trace.out_mod): + new_trace = cur_trace.deepcopy() + new_trace.add(knob, decision) + candidates.append(new_trace) + + # Expand candidates by using eval passes if provided. This will enable joint-optimization. + if eval_passes: + candidates = default_consider_eval_passes(candidates, eval_passes) + return candidates + + +@register_func("relax.tuning_api.default_consider_eval_passes") +def default_consider_eval_passes( + init_candidates: List[Trace], eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to update traces with eval passes. + It visits each eval_pass in dfs order in transform.Sequential() and + returns the best possible candidate trace for each candidate. + + Parameters + ---------- + init_candidates: List[Trace] + Initial candidates + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. 
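To make the candidate expansion concrete, a sketch with two knobs that either apply constant folding or leave the module untouched; every knob whose choices pass their constraints multiplies the candidate set (the registered function key and the IRModule `mod` are placeholders, and `relax.transform.FoldConstant` is assumed to be available):

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.relax.transform.tuning_api import Choice, Knob, Trace
    from tvm.relax.transform.tuning_api import default_generate_candidate

    @tvm.register_func("example.apply_fold_constant")
    def _apply_fold_constant(mod):
        return relax.transform.FoldConstant()(mod)

    fold = Choice("example.apply_fold_constant")
    skip = Choice()  # default transform is identity, default constraint always passes
    knobs = [
        Knob("FoldConstantA", {"apply": fold, "noapply": skip}),
        Knob("FoldConstantB", {"apply": fold, "noapply": skip}),
    ]

    # `mod` is assumed to be an existing relax IRModule.
    candidates = default_generate_candidate(knobs, Trace(mod))
    assert len(candidates) <= 4  # two binary knobs expand into at most 2 * 2 traces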
+ Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + if not eval_passes: + return init_candidates + + eval_passes = list(eval_passes) if not isinstance(eval_passes, list) else eval_passes + ctx = PassContext.current() + candidates = [] + + for trace in init_candidates: + ctx.push_trace(trace) + tvm.transform.Sequential(eval_passes)(trace.out_mod) + new_trace = ctx.pop_trace() + # A new trace contains the best decisions in eval_passes + candidates.append(new_trace) + + return candidates + + +@register_func("relax.tuning_api.default_evaluate") +def default_evaluate( + candidates: List[Trace], + target_str: str, + params: Optional[Dict[str, np.ndarray]] = None, + builder: Optional[meta_schedule.builder.Builder] = None, + runner: Optional[meta_schedule.runner.Runner] = None, +) -> None: + """ + Default function to evaluate a set of candidate traces by using MetaSchedule builder/runner. + + Parameters + ---------- + candidates: List[Trace] + List of traces to evaluate. + target_str: str, + Compilation target (e.g., llvm, cuda). + params: Optional[Dict[str, np.ndarray]] + Params to bind. + builder: Optional[meta_schedule.builder.Builder] + builder function. If not provided, default local builder will be used. + runner: Optional[meta_schedule.runner.Runner] + runner function. If not provided, default local runner will be used. + """ + + ctx = PassContext.current() + target = tvm.target.Target(target_str) + database = PassContext.current().get_tuning_api_database() + # Setup default local builder if not provided + if builder is None: + + def relax_build( + mod: IRModule, + target: tvm.target.Target, + params: Optional[Dict[str, np.ndarray]], + ): + if params: + mod = tvm.relax.transform.BindParams("main", params)(mod) + relax_exec = tvm.relax.build(mod, target) + return relax_exec.mod + + builder = LocalBuilder(f_build=relax_build) + + # Setup default local runner if not provided + if runner is None: + + def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): + relax_exec = tvm.relax.Executable(rt_mod) + relax_vm = tvm.relax.VirtualMachine(relax_exec, device=device) + + evaluator = relax_vm.module.time_evaluator( + func_name="main", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + + return costs + + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=3, repeat=5, min_repeat_ms=100, enable_cpu_cache_flush=False + ), + f_run_evaluator=relax_eval_func, + ) + + # set up clean up function + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + assert f_clean_build + + # Keep track of number of evaluations (mostly for the debugging purpose) + num_evals = 0 + # Evaluation + for candidate in candidates: + # If this candidate is already evaluated, skip the measurement + if candidate.perf != -1: + continue + + # Evaluate candidates + num_evals += 1 + mod = candidate.out_mod + workload = database.commit_workload(mod) + + # If this workload and target pair has measured before, fetch its data. + if database.has_measurement_record(workload, target): + run_secs = database.get_measurement_record(workload, target) + # Otherwise, measure it. 
+ else: + # Build candidate + (builder_result,) = builder.build([BuilderInput(mod, target, params)]) + + if builder_result.artifact_path is None: + # Build error + # Assign the worst performance and move on to the next candidate. + logger.warning(builder_result.error_msg) + run_secs = [1e100] + else: + # If build passes, set up runner input and measure the performance. + args_info = [ + TensorInfo( + shape=[int(i) for i in p.struct_info.shape], dtype=p.struct_info.dtype + ) + for p in mod["main"].params + ] # convert list[Var] to list[TensorInfo] + runner_input = RunnerInput( + builder_result.artifact_path, target_str, args_info=args_info + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + run_secs = runner_result.run_secs + # Runtime error + # Assign the worst performance and move on to the next candidate. + if runner_result.error_msg is not None: + logger.warning(runner_result.error_msg) + run_secs = [1e100] + + database.commit_measurement_record(workload, target, run_secs) + + # Clean up the artifact + f_clean_build(builder_result.artifact_path) + + # For valid measurments, compute the average and update the trace performance. + perfs = [] + for result in run_secs: + if isinstance(result, tvm.tir.FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + perfs.append(result) + + # Store the evaluation result + candidate.set_perf(np.mean(perfs)) + + ctx.inc_num_evals(num_evals) + + +def select_best_candidate(candidates: List[Trace]) -> Trace: + """ + Select the best trace. + + Parameters + ---------- + candidates: List[Trace] + Candidate traces + + Return + ---------- + best_trace: Trace + Trace with the best performance + """ + best_perf, best_trace = sys.maxsize, None + for candidate in candidates: + avg = candidate.perf + # Select best one + if best_perf > avg: + best_perf = avg + best_trace = candidate + return best_trace diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py new file mode 100644 index 000000000000..67b81ba7e99c --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API primitives""" + +from typing import Callable, Union, Dict, List, Optional, Sequence +import logging +import tvm +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.relax import Expr +from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm +from tvm._ffi import register_object +from . 
import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.Choice") +class Choice(Object): + """ + A TVM object Choice that maintains a set of transformation and constraint function keys. + Corresponding functions should be registered as PackedFunc with these keys. + Transformation function will be applied when constraint function returns true. + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + transform_func_args : Optional[List] + Arguments for transformation function. + constr_func_key : Optional[str] + Key for constraint function. + constr_func_args : Optional[List] + Arguments for constraint function. + + Examples + -------- + The following code block defines a Choice. + + .. code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + @tvm.register_func("relax.tuning_api.test.constr_func") + def constr(mod): + return len(mod.functions) == 3 + # Define a choice to apply constant folding only when IRModule has three functions. + choice = Choice( + transform_func_key = "relax.tuning_api.test.transform_func", + constr_func_key = "relax.tuning_api.test.constr_func" + ) + """ + + def __init__( + self, + transform_func_key: Optional[str] = None, + transform_func_args: Optional[List] = None, + constr_func_key: Optional[str] = None, + constr_func_args: Optional[List] = None, + ): + """Constructor + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + + f_tramsform_args: Optional[List] + Arguments for transformation function. + + constr_func_key : Optional[str] + Key for constraint function. + + constr_func_args: Optional[List] + Arguments for constraint function. 
+ """ + + if transform_func_key is None: + transform_func_key = "relax.tuning_api.Choice.default_transform_func" + + if transform_func_args is None: + transform_func_args = [] + + if constr_func_key is None: + constr_func_key = "relax.tuning_api.Choice.default_constr_func" + + if constr_func_args is None: + constr_func_args = [] + + self.__init_handle_by_constructor__( + _ffi_api.Choice, # type: ignore + transform_func_key, + transform_func_args, + constr_func_key, + constr_func_args, # type: ignore # pylint: disable=no-member + ) + + def get_transform_func(self) -> Callable: + """Getter for transform_func + Returns + ------- + ret: Callable + registered transformation function + """ + return _ffi_api.ChoiceGetTransformFunc(self) # type: ignore + + def get_constr_func(self) -> Callable: + """Getter for constr_func + Returns + ------- + ret: Callable + registered constraint function + """ + return _ffi_api.ChoiceGetConstrFunc(self) # type: ignore + + def apply_transform_func(self, mod: IRModule) -> IRModule: + """Perform transform_func with its arguments + Returns + ------- + ret: IRModule + Transformed IRModule + """ + return _ffi_api.ChoiceApplyTransformFunc(self, mod) # type: ignore + + def check_constr(self, mod: IRModule) -> bool: + """Perform constr_func with its arguments + Returns + ------- + ret: bool + Returns whether the IRModule satisfies the constraint or not + """ + return _ffi_api.ChoiceCheckConstr(self, mod) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.ChoiceAsJSON(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Choice": + """Create Choice from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Choice serialized with JSON + + Return + ---------- + choice: Choice + Deserialized choice + """ + return _ffi_api.ChoiceFromJSON(json_obj) # type: ignore + + def deepcopy(self): + return Choice.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Knob") +class Knob(Object): + """ + A TVM object Knob that maintains a set of valid Choices. + By using Knobs, a tuning pass can generate candidates and define the search space. + Parameters + ---------- + name : str + Name of the knob. + + choices: Union[List[Choice], Dict[str, Choice]] + A list of valid choices + + Examples + -------- + The following code block defines a Knob. + + .. 
code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + choices = {"apply": Choice("relax.tuning_api.test.transform_func"), "noapply": Choice()} + # A knob manages a set of its valid choices + knob = Knob("MockTuningKnob", choices) + """ + + def __init__(self, name: str, choices: Union[List[Choice], Dict[str, Choice]]): + """Constructor.""" + if isinstance(choices, list): + choices = {str(idx): val for idx, val in enumerate(choices)} + + self.__init_handle_by_constructor__( + _ffi_api.Knob, name, choices # type: ignore # pylint: disable=no-member + ) + + def verify(self, decision: Union[str, int]) -> bool: + """Verify if the decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobIsValidDecision(self, decision) # type: ignore + + def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule: + """Get choice if a decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobApply(self, mod, decision) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.KnobAsJSON(self) # type: ignore + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Knob": + """Create Knob from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Knob serialized with JSON + + Return + ---------- + knob: Knob + Deserialized knob + """ + return _ffi_api.KnobFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + msg = f"{self.name} (# of choices: {len(self.choices)})\n" + for name, choice in self.choices.items(): + msg += f" - {name}: {choice}\n" + return msg + + def deepcopy(self): + return Knob.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Trace") +class Trace(Object): + """ + A TVM object Trace logs the history of transformations (decisions). + Parameters + ---------- + in_mod : IRModule + Input IRModule. + knobs: Optional[List[Knob]] + A list of knobs applied in the trace. + decisions: Optional[Sequence[Union[str, int]]] + A list of decisions made for each knob + + Examples + -------- + The following code block defines a Trace. + + .. code-block:: python + + trace = Trace(mod, [knob1, knob2, knob3], ["c1", "c0", "c3"]) + assert trace.size == 3 # Length of history. + # 'out' contains IRModule that applies transformations in the trace. + out: IRModule = trace.add(knob4, "c2") + assert trace.size == 4 # Length of history. + trace.set_perf(0.03) # Set the performance number of the trace. 
+ """ + + def __init__( + self, + in_mod: IRModule, + knobs: Optional[List[Knob]] = None, + decisions: Optional[Sequence[Union[str, int]]] = None, + ): + """Constructor.""" + knobs = knobs if knobs else list() + decisions = ( + [str(v) if isinstance(v, int) else v for v in decisions] if decisions else list() + ) + self.__init_handle_by_constructor__( + _ffi_api.Trace, in_mod, knobs, decisions # type: ignore # pylint: disable=no-member + ) + + def verify(self) -> bool: + """Verify if current history is valid.""" + return _ffi_api.TraceVerify() # type: ignore + + def add(self, knob: Knob, decision: Union[str, int]) -> IRModule: + """Add & Apply new decision (with knob).""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.TraceAdd(self, knob, decision) # type: ignore + + def set_perf(self, perf: float) -> None: + """Set performance number for the trace.""" + return _ffi_api.TraceSetPerf(self, perf) # type: ignore + + def set_out_mod(self, mod: IRModule) -> None: + """Set out_mod for the trace.""" + return _ffi_api.TraceSetOutMod(self, mod) # type: ignore + + def as_json(self, include_irmod: bool = True) -> JSON_TYPE: + """Serialize the trace as a JSON-style object. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json: JSON_TYPE + The JSON-style object. + """ + obj = _ffi_api.TraceAsJSON(self, include_irmod) # type: ignore + return _json_from_tvm(obj) + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Trace": + """Create Trace from JSON obj. + + Parameters + ---------- + json_obj: JSON_TYPE + Trace serialized with JSON. + + Return + ---------- + trace: Trace + Deserialized trace. + """ + return _ffi_api.TraceFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + n = len(self.knobs) + msg = f"Trace length: {n}\n" + for idx in range(n): + msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" + return msg + + def deepcopy(self) -> "Trace": + new_in_mod = deepcopy_irmodule(self.in_mod) + new_knobs = [knob.deepcopy() for knob in self.knobs] + new_decisions = [str(decision) for decision in self.decisions] + new_trace = Trace(new_in_mod, new_knobs, new_decisions) + new_out_mod = deepcopy_irmodule(self.out_mod) + new_trace.set_out_mod(new_out_mod) + return new_trace + + +def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: + """ + Getter for a trace wrapper. + + Parameters + ---------- + in_: Union[Trace, IRModule, Expr] + Input entity + Return + ---------- + wrapped: Trace + Traced entity + """ + if isinstance(in_, Trace): + return in_ + if isinstance(in_, IRModule): + return Trace(in_) + if isinstance(in_, Expr): # type: ignore + return Trace(tvm.IRModule.from_expr(in_)) + + raise Exception(f"Invalid input type for trace: {type(in_)}") + + +@tvm.register_func("relax.tuning_api.deepcopy_irmodule") +def deepcopy_irmodule(mod: IRModule) -> IRModule: + """ + Deepcopy for an IRModule. + Parameters + ---------- + mod: IRModule + input IRModule + Return + ---------- + copied_mod: IRModule + deep-copied IRModule + """ + func_save_json = tvm.get_global_func("node.SaveJSON") + func_load_json = tvm.get_global_func("node.LoadJSON") + new_mod = None + # Handle external modules separately if exist + # TODO(tvm-team): + # Serialization of IRModule with external mods is tricky. + # (1) External mod is runtime module. 
+ # (2) Currently, `export_library` does not support serialization of + # runtime module without the host module + # Therefore, we simply pass around the compiled external modules without copy for now. + # Revisit later when we have a better solution. + if mod.attrs and "external_mods" in mod.attrs: + tmp_mod = mod.without_attr("external_mods") + new_mod = func_load_json(func_save_json(tmp_mod)) + new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) + else: + new_mod = func_load_json(func_save_json(mod)) + + return new_mod diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py new file mode 100644 index 000000000000..05492d6a9c34 --- /dev/null +++ b/python/tvm/relax/ty.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The type nodes of the Relax language.""" +import tvm._ffi +from tvm.ir import Type, TensorType, TupleType, FuncType, Span + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.ShapeType") +class ShapeType(Type): + """The type of shape in Relax. + + Parameters + ---------- + ndim : Optional[int] + The size of the shape. + """ + + # TODO(relax-team): consider make ndim mandatory + def __init__(self, ndim: int = -1, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore + + +@tvm._ffi.register_object("relax.ObjectType") +class ObjectType(Type): + """A type that corresponds to tvm::runtime::Object, is base of all possible object + values in TVM.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore + + +@tvm._ffi.register_object("relax.DynTensorType") +class DynTensorType(Type): + """A dynamic tensor type in Relax. + + This is the type assigned to tensors with a known dtype and unknown shape. + + Parameters + ---------- + ndim : Optional[int] + The ndim of the Tensor + + dtype : Optional[str] + The content data type. + """ + + def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DynTensorType, ndim, dtype, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.PackedFuncType") +class PackedFuncType(Type): + """The type of ExternFunc in Relax.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span) # type: ignore diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py new file mode 100644 index 000000000000..e8ff144f9533 --- /dev/null +++ b/python/tvm/relax/utils.py @@ -0,0 +1,452 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,too-many-locals +"""Utility functions for Relax""" +import functools +import inspect +from typing import Tuple as typing_Tuple +from typing import Any, Callable, List, Dict, Optional, TypeVar + +from .. import tir +from ..tir import PrimExpr +from ..runtime import String, convert_to_object +from . import _ffi_api +from .expr import Tuple as rx_Tuple +from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor +from ..te import Tensor as te_Tensor, create_relax_prim_func +from ..ir import Array, Attrs, Type, Map +from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo + + +def metadata_partitioner(rx_txt: str) -> List[str]: + """Extract Relax program and metadata section. + + Parameters + ---------- + rx_txt : str + The input relax text. + + Returns + ------- + output : List[str] + The result list of partitioned text, the first element + is the relax program, and the second is metadata section. + """ + partitions = [] + left_curly = 0 + meta_start = 0 + meta_end = 0 + for i, char in enumerate(rx_txt): + if i < 0: + raise ValueError("The program is invalid.") + if char == "{": + if meta_start == 0: + meta_start = i + left_curly += 1 + elif char == "}": + left_curly -= 1 + if left_curly == 0: + meta_end = i + 1 + break + + if meta_end == 0: + raise ValueError("The metadata section was not found.") + metadata = rx_txt[meta_start:meta_end] + rx_program = rx_txt[meta_end:-1] + + partitions.append(rx_program) + partitions.append(metadata) + + return partitions + + +def convert_to_expr(value: Any) -> Expr: + """Helper function to convert the input to Expr, which follows the rules: + 1. Return the input itself if it's already a `relax.Expr`; + 2. Return `relax.PrimValue` if the input is a `PrimExpr`; + 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; + 4. Return `relax.Tuple` if the input is a tuple/list of `Expr`. + + Notes + ----- + 1. `tvm.tir.StringImm` is not allowed because of ambiguity, + which can be either `relax.StringImm` or `relax.PrimValue`. 
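A short sketch of the conversion rules above in action:

.. code-block:: python

    from tvm import tir
    from tvm.relax.utils import convert_to_expr

    convert_to_expr(1)                       # relax.PrimValue holding an int64 (special-cased)
    convert_to_expr("label")                 # relax.StringImm (rule 3)
    convert_to_expr(tir.IntImm("int32", 2))  # relax.PrimValue (rule 2)
    convert_to_expr((1, "label"))            # relax.Tuple of the two values above (rule 4)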
+ """ + if isinstance(value, int): + return PrimValue(tir.IntImm("int64", value)) + + tvm_value = convert_to_object(value) + # Case 1 + if isinstance(tvm_value, Expr): # type: ignore + return tvm_value + # Note`` 1 + if isinstance(tvm_value, tir.StringImm): + raise TypeError( + "Cannot convert `tir.StringImm` to `relax.Expr` because of ambiguity," + "which can be either `relax.StringImm` or `relax.PrimValue` " + ) + # Case 2 + if isinstance(tvm_value, PrimExpr): + return PrimValue(value) + # Case 3 + if isinstance(tvm_value, String): + return StringImm(value) + # Case 4 + if isinstance(value, (tuple, list)): + # `convert_to_expr` ensures that all elements are `Expr` if no exception raises + return rx_Tuple([convert_to_expr(v) for v in value]) + raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`") + + +FType = TypeVar("FType", bound=Callable[..., Expr]) + + +class _ArgsConverter: + """A helper class to convert the arguments to Expr.""" + + @staticmethod + def convert(args_to_expr: List[str], args_to_list_expr: List[str]): + """Convert the arguments to Expr. + + Parameters + ---------- + args_to_expr : List[str] + The argument names to be converted to Expr. + + args_to_list_expr : List[str] + The argument names to be converted to List[Expr]. + + Returns + ------- + output : Callable[[FType], FType] + The decorator. + """ + + if any([x in args_to_list_expr for x in args_to_expr]): + raise ValueError(f"`args_to_expr` and `args_to_list_expr` should be disjoint.") + + def _convert(name: str, value: Any) -> Any: + if value is None: + return value + if name in args_to_expr: + try: + return convert_to_expr(value) + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `Expr`, " + f"but failed with input value: {value}" + ) + elif name in args_to_list_expr: + try: + return [convert_to_expr(x) for x in value] + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `List[Expr]`, " + f"but failed with input value: {value}" + ) + else: + return value + + def inner(func: FType) -> FType: + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + for name in args_to_expr + args_to_list_expr: + if name not in param_names: + raise ValueError(f"Argument `{name}` is not found in function signature.") + + @functools.wraps(func) + def wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + for param in sig.parameters.values(): + if param.kind == param.VAR_POSITIONAL: + # *args case + values = [_convert(param.name, x) for x in bound.arguments[param.name]] + bound.arguments[param.name] = tuple(values) + elif param.kind == param.VAR_KEYWORD: + # **kwargs case + key_value = { + key: _convert(param.name, value) + for key, value in bound.arguments[param.name].items() + } + bound.arguments[param.name] = key_value + else: + bound.arguments[param.name] = _convert( + param.name, bound.arguments[param.name] + ) + return func(*bound.args, **bound.kwargs) + + return wrapper # type: ignore + + return inner + + @staticmethod + def to_expr(*arg_names: str) -> Callable: + """Convert the arguments to Expr. + + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[]) + + @staticmethod + def to_list_expr(*arg_names: str) -> Callable: + """Convert the arguments to List of Expr. 
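A sketch of the converter in use, relying on the `args_converter` instance exposed at the end of this file (the decorated function itself is made up):

.. code-block:: python

    from typing import List, Optional
    from tvm.relax import Expr
    from tvm.relax.utils import args_converter

    @args_converter.to_list_expr("tensors")
    def bundle(tensors: List[Expr], name: Optional[str] = None):
        # By the time the body runs, plain Python values in `tensors`
        # have been converted to relax.Expr via convert_to_expr.
        return tensors

    exprs = bundle([1, "label"])  # -> [relax.PrimValue, relax.StringImm]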
+ + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to List of Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names)) + + @staticmethod + def auto(func: FType) -> FType: + """Decorator for automatically convert the arguments to Expr according to type annotation. + Only two patterns are supported: + + 1. The argument is Expr or Optional[Expr]. + + 2. The argument is List[Expr] or Optional[List[Expr]]. + + """ + sig = inspect.signature(func) + args_to_expr = [] + args_to_list_expr = [] + + for param in sig.parameters.values(): + anno = param.annotation + if anno in (Expr, Optional[Expr]): + args_to_expr.append(param.name) + if anno in (List[Expr], Optional[List[Expr]]): + args_to_list_expr.append(param.name) + + return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func) + + +args_converter = _ArgsConverter() # pylint: disable=invalid-name + + +def copy_with_new_vars(func: Function) -> Function: + """Copy the given function. All variables that are bound inside the original function + would be copied to satisfy the restriction in the well-formed check: Variables in + Relax must be bound exactly once. This also ensures that both the function and its copy + can be inserted into the same IRModule, and be asserted on the structural equality + agaisnt IRModule created by TVMScript. + + Parameters + ---------- + func : Function + The relax function to copy. + + Returns + ------- + ret : Function + The copied function. + """ + return _ffi_api.CopyWithNewVars(func) # type: ignore + + +def gen_call_tir_inputs( + func: Callable, *args: Any, **kwargs: Any +) -> typing_Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]]: + """Generate the inputs for call_tir according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the keyword args 'primfunc_attrs' is reserved for passing func + attributes to be added to the PrimFunc that gets created. + + Returns + ------- + ret : Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]] + ret contains the inputs for call_tir, including a tir prim_func, args, + out_sinfo, and tir_vars. + """ + + def _convert_te_arg( + te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr] + ) -> typing_Tuple[Any, List[te_Tensor]]: + """Helper function used to convert Relax expressions to TE tensor. + + In the common case, the type of te_args is a Relax expression and is converted + into a TE tensor. + If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array), + we recursive and convert any value of type Relax expression into a TE tensor. + Common values of type int, float, and str are preserved. + + In dynamic shape cases, the passed in arguments may contain TIR variable. + For example, the argument can be a Relax Var with TensorStructInfo, which + has symbolic shape, or the argument can be a ShapeExpr with symbolic variables. 
+ To make the PrimFunc generated has independent variables with + the caller Relax function, we will substitute the TIR variables in the input + arguments with fresh ones, which is done by maintaining a TIR variable mapping. + + Parameters + ---------- + te_args : Any + Argument to convert to TE + + tir_var_map : Dict[tir.Var, tir.PrimExpr] + The TIR variable mapping, which maps TIR variables on the Relax function + side to the new set of variables used on the PrimFunc side. + + Returns + ------- + ret : (Any, [tvm.te.Tensor]) + A tuple of the converted te_args, and a list of te tensors for each converted + Relax expression + """ + te_args_list = [] + + def _copy_undefined_var(expr: tir.PrimExpr): + def _visit_expr(e: tir.PrimExpr): + if isinstance(e, tir.Var) and e not in tir_var_map: + new_var = tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + tir.stmt_functor.post_order_visit(expr, _visit_expr) + + def _convert_te_arg_helper(arg): + if isinstance(arg, Expr): # type: ignore + if isinstance(arg.struct_info, TensorStructInfo): + assert isinstance( + arg.struct_info.shape, ShapeExpr + ), "emit_te now only supports Tensor that has ShapeExpr shape" + for shape_value in arg.struct_info.shape.values: + _copy_undefined_var(shape_value) + + arg = te_tensor(arg, tir_var_map) + te_args_list.append(arg) + return arg + if isinstance(arg.struct_info, ShapeStructInfo): + assert isinstance( + arg, ShapeExpr + ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" + return [_convert_te_arg_helper(val) for val in arg.values] + if isinstance(arg.struct_info, PrimStructInfo): + return arg.value + elif isinstance(arg, (list, Array)): + return [_convert_te_arg_helper(x) for x in arg] + elif isinstance(arg, tuple): + return tuple(_convert_te_arg_helper(x) for x in arg) + elif isinstance(arg, (dict, Map)): + for key in arg: + assert isinstance( + key, str + ), "emit_te only supports dict with string as the key currently" + return {k: _convert_te_arg_helper(arg[k]) for k in arg} + elif isinstance(arg, tir.PrimExpr): + _copy_undefined_var(arg) + return tir.stmt_functor.substitute(arg, tir_var_map) + elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None: + return arg + raise TypeError("not supported type in emit_te: {}".format(type(arg))) + + new_arg = _convert_te_arg_helper(te_args) + return new_arg, te_args_list + + def _get_unbound_tir_vars(args: List[te_Tensor]) -> List[tir.Var]: + """get unbound TIR vars (i.e TIR vars used in the shape but is not + itself a dimension of a shape)""" + bound_vars = set() + used_vars = set() + + def _populate_used_vars(expr): + if isinstance(expr, tir.Var): + used_vars.add(expr) + + for x in args: + for s in x.shape: + tir.stmt_functor.post_order_visit(s, _populate_used_vars) + if isinstance(s, tir.Var): + bound_vars.add(s) + + diff = used_vars - bound_vars + return list(diff) + + def _shape_with_old_tir_var( + shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr] + ): + return ShapeExpr( + [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] + ) + + primfunc_attrs = kwargs.pop("primfunc_attrs", None) + + tir_var_map: Dict[tir.Var, tir.PrimExpr] = {} + new_args, te_arg_list = _convert_te_arg(args, tir_var_map) + new_kwargs, te_kwarg_list = _convert_te_arg(kwargs, tir_var_map) + + te_args = te_arg_list + te_kwarg_list + + te_out = func(*new_args, **new_kwargs) + assert isinstance(te_out, te_Tensor) or ( + isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for 
t in te_out) + ), "only support te.tensor or tuple/list/Array of te.tensor as function output" + + outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out) + unbound_tir_vars = _get_unbound_tir_vars(te_args + outs) + + inputs = [*te_args] + outs + tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64") + + if primfunc_attrs: + tir_func = tir_func.with_attrs(primfunc_attrs) + + tir_func = tir_func.without_attr("global_symbol") + + call_tir_args = [x.op.value for x in te_args] + + # Invert the TIR variable mapping, to convert the output shape back + # with old set of variables. + tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} + + output_sinfo = [ + TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) + for out in outs + ] + + tir_vars = None + if len(unbound_tir_vars) > 0: + tir_vars = _shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map) + + return (tir_func, call_tir_args, output_sinfo, tir_vars) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py new file mode 100644 index 000000000000..0586bf9217a2 --- /dev/null +++ b/python/tvm/relax/vm_build.py @@ -0,0 +1,329 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member +"""VM build logics""" +from typing import List, Optional, Union, Dict, Any + +import tvm +from tvm import relax + +from tvm.contrib import utils as _utils + +from tvm.ir.module import IRModule +from tvm.tir.function import PrimFunc + +from . import _ffi_api + + +class Executable: + """The executable object emitted by the VM compiler or the ExecBuilder.""" + + def __init__(self, mod: tvm.runtime.Module): + self.mod = mod + self._stats = self.mod["stats"] + self._as_text = self.mod["as_text"] + self._as_python = self.mod["as_python"] + + def stats(self) -> str: + """print the detailed statistics of the executable.""" + return self._stats() + + def as_text(self) -> str: + """print the instructions as text format.""" + return self._as_text() + + def as_python(self) -> str: + """print the instructions as python program.""" + return self._as_python() + + def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module: + """Just-in-time compile and link the modules. + + The Executable returned by relax.build may not be directly + runnable as they may contain cuda source files and objects that + are yet to be compiled and linked. + This function helps to create a runtime.Module for these cases. 
+ + Parameters + ---------- + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + rt_mod: tvm.runtime.Module + A runnable runtime module that can be passed to VirtualMachine. + + Examples + -------- + .. code:: python + + ex = relax.build(mod, target) + # build a runnable module using nvcc to link everything + rt_mod = ex.jit() + vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) + """ + # TODO(tvm-team): Update runtime.Module interfac + # to query these properties as bitmask. + def _not_runnable(x): + return x.type_key in ("c", "static_library") + + # pylint:disable = protected-access + not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) + + # everything is runnable, directly return mod. + if len(not_runnable_list) == 0: + return self.mod + + # found source module, or other not runnable modules + # need to be export and load + # TODO(tvm-team): Support runnable but not exportable module. + # by collecting the link and allow export_library skip those modules. + workspace_dir = _utils.tempdir() + dso_path = workspace_dir.relpath("exported.so") + self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs) + return tvm.runtime.load_module(dso_path) + + def export_library( + self, + file_name: str, + fcompile: Optional[Union[str, callable]] = None, + workspace_dir: Optional[str] = None, + **kwargs, + ) -> Any: + """Export the executable to a library which can then be loaded back. + + Parameters + ---------- + file_name : str + The name of the shared library. + + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + workspace_dir : str, optional + The path of the directory used to create the intermediate + artifacts when exporting the module. + If this is not provided a temporary dir will be created. + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + result of fcompile() : unknown, optional + If the compilation function returns an artifact it would be returned via + export_library, if any. + + Examples + -------- + .. code:: python + + ex = relax.build(mod, target) + # export the library + ex.export_library("exported.so") + + # load it back for future uses. + rt_mod = tvm.runtime.load_module("exported.so") + vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) + """ + return self.mod.export_library( + file_name=file_name, fcompile=fcompile, workspace_dir=workspace_dir, **kwargs + ) + + +def _vmcodegen( + builder: "relax.ExecBuilder", + mod: tvm.IRModule, + exec_mode: str = "bytecode", +) -> tvm.IRModule: + """Running VM codegen. + + Parameters + ---------- + builder: relax.ExecBuilder + ExecBuilder to collect the vm executable. + + mod: IRModule + The input IRModule to be built. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Return + ------ + leftover: IRModule + Left over IRModule that may contain extra functions. 
+ """ + + if exec_mode == "bytecode": + return _ffi_api.VMCodeGen(builder, mod) # type:ignore + if exec_mode == "compiled": + return _ffi_api.VMTIRCodeGen(builder, mod) # type: ignore + raise ValueError("Unknown exec_mode %s" % exec_mode) + + +def _autodetect_system_lib_req(target: tvm.target.Target): + """Automatically detect system lib requirement""" + host = target if target.host is None else target.host + system_lib = False + if "wasm" in host.attrs.get("mtriple", ""): + system_lib = True + if system_lib: + # use packed-func to avoid relay dep. + return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib}) + return None + + +def _vmlink( + builder: "relax.ExecBuilder", + target: Union[str, tvm.target.Target], + tir_mod: Optional[tvm.IRModule] = None, + ext_libs: List[tvm.runtime.Module] = None, + params: Optional[Dict[str, list]] = None, +): + """ + Internal codegen function to make executable. + + This function is only used for unit-testing purpoes. + + Use build instead. + + Parameters + ---------- + builder: relax.ExecBuilder + Builder used to collect executables. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + tir_mod: IRModule + The input TIR IRModule to be linked together. + + ext_libs: List[tvm.runtime.Module] + List of compiled external modules. + + params: Optional[Dict[str, list]] + Extra parameter mappings. + + Returns + ------- + ex: tvm.relax.Executable + An executable that can be loaded by virtual machine. + """ + if isinstance(target, str): + target = tvm.target.Target(target) + if params is None: + params = {} + if ext_libs is None: + ext_libs = [] + lib = None + if tir_mod is not None: + lib = tvm.build(tir_mod, target=target, runtime=_autodetect_system_lib_req(target)) + return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore + + +def build( + mod: tvm.IRModule, + target: Union[str, tvm.target.Target], + params: Optional[Dict[str, list]] = None, + exec_mode: str = "bytecode", +) -> Executable: + """ + Build an IRModule to VM executable. + + Parameters + ---------- + mod: IRModule + The input IRModule to be built. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm interpreter is used. + + params: Optional[Dict[str, list]] + Parameters for the input IRModule that will be bound. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Returns + ------- + ex: tvm.relax.Executable + An executable that can be loaded by virtual machine. + + Example + ------- + + .. 
code-block:: python + class InputModule: + @R.function + def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + z = R.add(x, y) + return z + + mod = InputModule + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target) + """ + if isinstance(target, str): + target = tvm.target.Target(target) + + passes = [] + passes.append(relax.transform.RewriteDataflowReshape()) + passes.append(relax.transform.ToNonDataflow()) + passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.StaticPlanBlockMemory()) + passes.append(relax.transform.VMBuiltinLower()) + passes.append(relax.transform.VMShapeLower()) + passes.append(relax.transform.AttachGlobalSymbol()) + seq = tvm.transform.Sequential(passes) + new_mod = seq(mod) + + # Extract external runtime modules if exist. + attrs = dict(mod.attrs) if mod.attrs else {} + + ext_libs = attrs.get("external_mods", []) + constants = attrs.get("const_name_to_constant", {}) + + if params is not None: + params.update(dict(constants)) + else: + params = constants + + # builder collects the executable + builder = relax.ExecBuilder() + leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode) + tir_mod = _filter_tir(leftover_mod) + return _vmlink(builder, target, tir_mod, ext_libs, params) + + +def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule: + tir_mod = IRModule({}) + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + tir_mod[gv] = mod[gv] + return tir_mod diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 4e9a9a4707a1..deae9e2f48be 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -439,16 +439,6 @@ class AffineGridAttrs(Attrs): """Attributes used in affine_grid operators""" -@tvm._ffi.register_object("relay.attrs.AllocStorageAttrs") -class AllocStorageAttrs(Attrs): - """Attributes used in alloc_storage operators""" - - -@tvm._ffi.register_object("relay.attrs.AllocTensorAttrs") -class AllocTensorAttrs(Attrs): - """Attributes used in alloc_tensor operators""" - - @tvm._ffi.register_object("relay.attrs.CastHintAttrs") class CastHintAttrs(Attrs): """Attributes used in cast_hint annotation operators""" diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index d7027c88a4b5..59af53d4e164 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -203,11 +203,20 @@ def signal_close(self): self.close() +MIME_MAP = { + "js": "application/javascript", + "wasm": "application/wasm", + "json": "application/json", +} + + class RequestHandler(tornado.web.RequestHandler): """Handles html request.""" def __init__(self, *args, **kwargs): file_path = kwargs.pop("file_path") + self.format = file_path.split(".")[-1] + if file_path.endswith("html"): self.page = open(file_path).read() web_port = kwargs.pop("rpc_web_port", None) @@ -217,12 +226,15 @@ def __init__(self, *args, **kwargs): ) else: self.page = open(file_path, "rb").read() + super(RequestHandler, self).__init__(*args, **kwargs) def data_received(self, _): pass def get(self, *args, **kwargs): + if self.format in MIME_MAP: + self.set_header("Content-Type", MIME_MAP[self.format]) self.write(self.page) @@ -254,9 +266,14 @@ def __init__( ) logging.info("Serving RPC index html page at http://localhost:%d", web_port) resource_files = resource_files if resource_files else [] - for fname in resource_files: + for item in resource_files: + prefix, fname = item + if not prefix.endswith("/"): + prefix += "/" + if not prefix.startswith("/"): + prefix = "/" + prefix 
basename = os.path.basename(fname) - pair = (r"/%s" % basename, RequestHandler, {"file_path": fname}) + pair = (r"%s%s" % (prefix, basename), RequestHandler, {"file_path": fname}) handlers.append(pair) logging.info(pair) self.app = tornado.web.Application(handlers) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index c78a6d9c3136..9356c19c4bda 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -293,6 +293,10 @@ def is_dso_exportable(self): """ return (self.get_property_mask() & ModulePropertyMask.DSO_EXPORTABLE) != 0 + def clear_imports(self): + """Remove all imports of the module.""" + _ffi_api.ModuleClearImports(self) + def save(self, file_name, fmt=""): """Save the module to file. @@ -487,7 +491,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No raise RuntimeError("Cannot call export_library in runtime only mode") # Extra dependencies during runtime. from pathlib import Path - from tvm.contrib import cc as _cc, tar as _tar, utils as _utils + from tvm.contrib import cc as _cc, tar as _tar, utils as _utils, tvmjs as _tvmjs if isinstance(file_name, Path): file_name = str(file_name) @@ -540,7 +544,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No object_format = "cu" has_c_module = True else: - assert module.type_key == "llvm" or module.type_key == "static_library" + assert module.is_dso_exportable global_object_format = object_format = "o" path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}") @@ -552,6 +556,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No if not fcompile: if file_name.endswith(".tar"): fcompile = _tar.tar + elif file_name.endswith(".wasm"): + fcompile = _tvmjs.create_tvmjs_wasm else: fcompile = _cc.create_shared diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py new file mode 100644 index 000000000000..c53882095d6c --- /dev/null +++ b/python/tvm/runtime/relax_vm.py @@ -0,0 +1,508 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
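With the `export_library` change above, a `.wasm` artifact can be requested directly by file suffix. A sketch, assuming an Emscripten toolchain is configured and `mod` is a relax IRModule (the target triple is a placeholder):

.. code-block:: python

    import tvm
    from tvm import relax

    target = tvm.target.Target("llvm -mtriple=wasm32-unknown-unknown-wasm")
    ex = relax.build(mod, target)
    # A ".wasm" suffix now routes compilation to tvm.contrib.tvmjs.create_tvmjs_wasm.
    ex.export_library("model.wasm")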
+# pylint: disable=invalid-name, redefined-builtin, no-else-return, consider-using-dict-items +"""The Relax virtual machine.""" +from typing import Callable, List, Optional, Union, Dict, Tuple, Any +from enum import IntEnum +import numpy as np # type: ignore + +import tvm +from tvm._ffi import base as _base + +from tvm.runtime import Device, PackedFunc, container, Object +from tvm.runtime.profiling import Report + +from ..rpc.base import RPC_SESS_MASK + + +class VMInstrumentReturnKind(IntEnum): + NO_OP = 0 + # skip the following call, only valid in before + SKIP_RUN = 1 + + +class VirtualMachine(object): + """Relax VM runtime.""" + + NAIVE_ALLOCATOR = 1 + POOLED_ALLOCATOR = 2 + + def __init__( + self, + rt_mod: Union[tvm.runtime.Module, "tvm.relax.Executable"], + device: Union[Device, List[Device]], + memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + profile: bool = False, + ) -> None: + """ + Construct a VirtualMachine wrapper object. + + Parameters + ---------- + mod: Union[tvm.runtime.Module, tvm.relax.Executable] + Runtime module exported by the result of build. + + device : Union[Device, List[Device]] + The device to deploy the module. + + memory_cfg : Optional[Union[str, Dict[Device, str]]] + Config the type of memory allocator. The allocator type can be ["naive", + "pooled"]. If memory_cfg is None, all devices will use pooled allocator + by default. If memory_cfg is string, all devices will use the specified + allocator type. If memory_cfg is a dict, each device uses the allocator + type specified in the dict, or pooled allocator if not specified in the + dict. + + profile : Optional[bool] + Whether or not to enable profiling. + """ + if not isinstance(rt_mod, tvm.runtime.Module): + # important to keep this import local + # as the relax_vm needs to be isolated from compiler + # if we do not use the jit feature + # pylint:disable=import-outside-toplevel + from tvm import relax + + if isinstance(rt_mod, relax.Executable): + rt_mod = rt_mod.jit() + else: + raise ValueError("Expect the rt_mod to be an runtime.Module") + + load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable" + self.module = rt_mod[load_exec]() + self._invoke_closure = self.module["invoke_closure"] + self._save_function = self.module["save_function"] + self._set_input = self.module["set_input"] + self._invoke_stateful = self.module["invoke_stateful"] + self._get_output = self.module["get_output"] + self._get_output_arity = self.module["get_output_arity"] + self._get_function_arity = self.module["get_function_arity"] + self._get_function_param_name = self.module["get_function_param_name"] + self._set_instrument = self.module["set_instrument"] + self._setup_device(device, memory_cfg) + + def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: + """init devices and allocators.""" + devs = dev + if not isinstance(dev, (list, tuple)): + if not isinstance(dev, tvm.runtime.Device): + raise TypeError( + "dev is expected to be Device or \ + List[Device]" + ) + devs = [dev] + + if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]): + raise RuntimeError( + "CPU host is required to be the last element of the device list if provided." 
+ ) + + # CPU is required for executing shape functions + if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type: + devs.append(tvm.cpu()) + + default_alloc_type = VirtualMachine.POOLED_ALLOCATOR + if memory_cfg is None: + memory_cfg = {} + elif isinstance(memory_cfg, str): + assert memory_cfg in ["naive", "pooled"] + if memory_cfg == "naive": + default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR + memory_cfg = {} + elif not isinstance(memory_cfg, dict): + raise TypeError( + "memory_cfg is expected be string or dictionary, " + + "but received {}".format(type(memory_cfg)) + ) + init_args = [] + for device in devs: + init_args.append(device.device_type % RPC_SESS_MASK) + init_args.append(device.device_id) + alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type + init_args.append(alloc_type) + self.module["vm_initialization"](*init_args) + + def __getitem__(self, key: str) -> PackedFunc: + return self.module[key] + + def invoke_closure(self, closure: Object, *args: Any) -> Object: + """Invoke a closure. + + Parameters + ---------- + closure : Object + The VMClosure Object. + + args : list[tvm.runtime.NDArray] or list[np.ndarray] + The arguments to the closure. + + Returns + ------- + result : Object + The output. + """ + return self._invoke_closure(closure, *args) + + def save_function( + self, + func_name: str, + saved_name: str, + *args: List[Any], + include_return: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + """ + Convenience function. Takes a function from the module and saves + a `PackedFunc` that, when called, will invoke the function with the given arguments. + The `PackedFunc` can be accessed from the module using `saved_name`. + This is included to facilitate timing trials: + Invoking the returned `PackedFunc` will have less overhead from dictionary lookups + than normally running through the VM. + + If the saved name is taken, it can be overridden, though it cannot override + the name of a function defined in the Relax source. + + This is really creating a closure, but the function has a different name + to avoid confusion with `invoke_closure` (they are not meant to be used together). + + Parameters + ---------- + func_name : str + The function that should be packaged up. + + saved_name : str + The name that the resulting closure should be saved under. + + include_return : bool + Whether the saved PackedFunc should return its output. + If timing over RPC, it may not be desirable to send output + between machines. + + args : List[Any] + The arguments to package up with the function. 
+ + kwargs : Dict[str, Any] + Any named arguments to package up with the function + """ + cargs: List[Any] = [] + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + for arg in args: + self._convert(arg, cargs) + self._save_function(func_name, saved_name, int(include_return), *cargs) + + def _convert(self, arg: Any, cargs: List) -> None: + """helper function to convert arguments to vm function.""" + + def _gettype(arg): + if isinstance(arg, np.float16): + return "float16" + elif isinstance(arg, (_base.integer_types, bool)): + return "int32" + else: + return "float32" + + if isinstance(arg, Object): + cargs.append(arg) + elif isinstance(arg, np.ndarray): + nd_arr = tvm.nd.array(arg, device=tvm.cpu(0)) + cargs.append(nd_arr) + elif isinstance(arg, tvm.runtime.NDArray): + cargs.append(arg) + elif isinstance(arg, (tuple, list)): + field_args: List[Any] = [] + for field in arg: + self._convert(field, field_args) + cargs.append(container.tuple_object(field_args)) + elif isinstance(arg, (_base.numeric_types, bool)): + dtype = _gettype(arg) + value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) + cargs.append(value) + elif isinstance(arg, str): + cargs.append(arg) + else: + raise TypeError("Unsupported type: %s" % (type(arg))) + + def _convert_func_named_args(self, func_name: str, args: Any, **kwargs: Any) -> Any: + """ + Takes named function parameters and returns a list of those needed, + in the order they should appear + """ + # kwargs can be a super set of the required function parameters. + # We only find the ones that are needed. + func_arity = self._get_function_arity(func_name) + func_params = [self._get_function_param_name(func_name, i) for i in range(func_arity)] + new_args = [None] * len(func_params) + cnt = 0 + for k in kwargs: + if k in func_params: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + cnt += 1 + else: + print(f'Warning: Keyword argument "{k}" is unused in {func_name}') + assert len(args) + cnt == len(func_params) + idx = 0 + for i, arg in enumerate(new_args): + if arg is None: + new_args[i] = args[idx] + idx += 1 + return new_args + + def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: + """Set the inputs to a function. + This interface works when using VM over RPC by internally converting NDArray in + the arguments to DLTensor, which is supported in RPC where remote could only + have a minimal C runtime. + + Note: If `set_input` is used, the function *must* be called using `invoke_stateful` + and the results must be obtained using `get_outputs`. + + Parameters + ---------- + func_name : str + The name of the function. + args: List[tvm.runtime.NDArray] or List[np.ndarray] + The arguments to the function. + kwargs: dict of str to tvm.runtime.NDArray or np.ndarray + Named arguments to the function. + """ + cargs: List[Any] = [] + + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + + for arg in args: + self._convert(arg, cargs) + + self._set_input(func_name, *cargs) + + def invoke_stateful(self, func_name: str) -> None: + """ + Call the named function from the VM module using the arguments set using `set_input`. + It is an error to call `invoke_stateful` without using `set_input` first + (even if it's to set 0 inputs); conversely, if `set_input` has been called, + it is an error to call the function without using `invoke_stateful`. + + The results of the call can be obtained by calling `get_outputs`. 
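
The stateful entry points wired up here (`set_input`, `invoke_stateful`, `get_outputs`) are meant to be used as a unit, alongside the direct `vm["main"](...)` path. A rough usage sketch, assuming an LLVM-enabled build; the module name, shapes, and input data are illustrative:

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class AddSelf:
        @R.function
        def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
            y = R.add(x, x)
            return y

    ex = relax.build(AddSelf, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    data = tvm.nd.array(np.arange(4, dtype="float32"))

    out = vm["main"](data)        # direct call path

    vm.set_input("main", data)    # stateful path: inputs are converted for RPC use
    vm.invoke_stateful("main")
    out2 = vm.get_outputs("main")
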
+ + Parameters + ---------- + func_name: str + The name of the function to call. + """ + self._invoke_stateful(func_name) + + def get_outputs(self, func_name: str) -> Union[tvm.Object, Tuple[Any]]: + """ + Get the value output by the function by the given name + after a call of `invoke_stateful`. + + It is an error to call this function without first calling `invoke_stateful`. + + Parameters + ---------- + func_name: str + The name of the function whose output should be fetched. + + Returns + ------- + ret: Union[tvm.Object, Tuple[Any]] + The result of the earlier call to the function via `invoke_stateful`. + If the result is a tuple, it returns a list of the fields. + The fields are potentially also tuples, so these can be arbitrily nested. + """ + # to deal with potentially nested tuples, we need to query for arity recursively + def get_output_rec(func_name, *idx): + arity = self._get_output_arity(func_name, *idx) + if arity == -1: + return self._get_output(func_name, *idx) + # otherwise we need to specify more indices + idx_list = list(idx) + return tuple(get_output_rec(func_name, *(idx_list + [i])) for i in range(arity)) + + return get_output_rec(func_name) + + def set_instrument(self, instrument: tvm.runtime.PackedFunc): + """Set an instrumentation function. + + If instrument is present, the function will be called + before/after each Call instruction. The function have + the following signature: + + .. code:: python + + def instrument( + func: Union[VMClosure, PackedFunc], + func_symbol: str, + before_run: bool, + ret_value: any, + *args) -> bool: + pass + + The instrument takes the following parameters: + - func: function object to be called. + - func_symbol: the symbol name of the function. + - before_run: whether it is before or after call. + - ret_value: the return value of the call, only valid after run. + - args: the arguments being passed to call. + + The instrument function can choose an integer, + which corresponds to action direction for the + following run. See VMInstrumentReturnKind for + more details. + + Parameters + ---------- + instrument: tvm.runtime.PackedFunc + A instrumentation function that get invoked every VM call instr. + + See Also + -------- + VMInstrumentReturnKind: the possible return values in VM. + """ + self._set_instrument(instrument) + + def time_evaluator( + self, + func_name, + dev, + number=10, + repeat=1, + min_repeat_ms=0, + cooldown_interval_ms=0, + repeats_to_cooldown=1, + f_preproc="", + ) -> Callable[..., tvm.runtime.module.BenchmarkResult]: + """ + Returns an evaluator that times a function in the module. + This follows the same convention as time_evaluator in tvm.runtime.module. + This can be used in combination with save_function() so that the + timings avoid extra dictionary lookups. + + Parameters + ---------- + func_name: str + The name of the function in the module. + + dev: Device + The device we should run this function on. + + number: int + The number of times to run this function for taking average. + We call these runs as one `repeat` of measurement. + + repeat: int, optional + The number of times to repeat the measurement. + In total, the function will be invoked (1 + number x repeat) times, + where the first one is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. 
If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + + cooldown_interval_ms: int, optional + The cooldown interval in milliseconds between the number of repeats defined by + `repeats_to_cooldown`. + + repeats_to_cooldown: int, optional + The number of repeats before the cooldown is activated. + + f_preproc: str, optional + The preprocess function name we want to execute before executing the time evaluator. + + Note + ---- + The function will be invoked (1 + number x repeat) times, + with the first call discarded in case there is lazy initialization. + + Example + ------- + Normal use with a VM function (may not work over RPC if the function returns a tuple): + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + timing_res = vm.time_evaluator("func_name", tvm.cpu())(arg0, arg1, ..., argn) + + Use with the stateful API: + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + vm.set_input("func_name", arg0, arg1, ..., argn) + timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("func_name") + + With saved closures via `save_function` (this results in + fewer dictionary lookups in the timed portion): + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + vm.save_function("func_name", "func_name_saved", arg0, arg1, ..., argn) + timing_res = vm.time_evaluator("func_name_saved", tvm.cpu())() + + Returns + ------- + ftimer : function + The function that takes same argument as func and returns a BenchmarkResult. + The ProfileResult reports `repeat` time costs in seconds. + + """ + return self.module.time_evaluator( + func_name, + dev, + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms, + cooldown_interval_ms=cooldown_interval_ms, + repeats_to_cooldown=repeats_to_cooldown, + f_preproc=f_preproc, + ) + + def profile(self, func_name: str, *args): + """Profile a function call. + Parameters + ---------- + func_name : str + The name of the function. + args: List of NDArray or other objects supported by PackedFunc. + The arguments to the function. + Returns + ------- + report: tvm.runtime.profiling.Report + The formatted profiling result, showing per-op timing measurements. 
+ """ + cargs: List[Any] = [] + + for arg in args: + self._convert(arg, cargs) + + report_json = self.module["profile"](func_name, *cargs) + return Report.from_json(report_json) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 269cab8e5d4d..b84f5534930c 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -32,6 +32,8 @@ class PrinterConfig(Object): show_meta: bool ir_prefix: str tir_prefix: str + relax_prefix: str + module_alias: str buffer_dtype: str int_dtype: str float_dtype: str @@ -52,6 +54,8 @@ def __init__( show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", + relax_prefix: str = "R", + module_alias: str = "cls", buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", @@ -71,6 +75,8 @@ def __init__( "show_meta": show_meta, "ir_prefix": ir_prefix, "tir_prefix": tir_prefix, + "relax_prefix": relax_prefix, + "module_alias": module_alias, "buffer_dtype": buffer_dtype, "int_dtype": int_dtype, "float_dtype": float_dtype, @@ -111,6 +117,8 @@ def script( show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", + relax_prefix: str = "R", + module_alias: str = "cls", buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", @@ -136,7 +144,11 @@ def script( The prefix of AST nodes from tvm.ir tir_prefix : str = "T" The prefix of AST nodes from tvm.tir - + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + module_alias : str = "cls" + The alias of the current module at cross-function call, + Directly use module name if it's empty. buffer_dtype : str = "float32" The default data type of buffer int_dtype : str = "int32" @@ -174,6 +186,54 @@ def script( show_meta=show_meta, ir_prefix=ir_prefix, tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + module_alias=module_alias, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + syntax_sugar=syntax_sugar, + path_to_underline=path_to_underline, + path_to_annotate=path_to_annotate, + obj_to_underline=obj_to_underline, + obj_to_annotate=obj_to_annotate, + ), + ) + + def _relax_script( + self, + *, + name: Optional[str] = None, + show_meta: bool = False, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + module_alias: str = "cls", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + syntax_sugar: bool = True, + path_to_underline: Optional[List[ObjectPath]] = None, + path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + obj_to_underline: Optional[List[Object]] = None, + obj_to_annotate: Optional[Dict[Object, str]] = None, + ) -> str: + return _relax_script( + self, + PrinterConfig( + name=name, + show_meta=show_meta, + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + module_alias=module_alias, buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, @@ -198,6 +258,8 @@ def show( show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", + relax_prefix: str = "R", + module_alias: str = "cls", buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", @@ -228,7 +290,11 @@ def show( The prefix of AST nodes from tvm.ir tir_prefix : str = 
"T" The prefix of AST nodes from tvm.tir - + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + module_alias : str = "cls" + The alias of the current module at cross-function call, + Directly use module name if it's empty. buffer_dtype : str = "float32" The default data type of buffer int_dtype : str = "int32" @@ -264,6 +330,8 @@ def show( show_meta=show_meta, ir_prefix=ir_prefix, tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + module_alias=module_alias, buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 9283727ad41a..f5ee692cbb8f 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -17,4 +17,3 @@ """TVM Script APIs of TVM Python Package""" from .parser import ir, ir_module from .parser import parse as from_source -from .parser import tir diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. @@ -136,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..68eda2cfeebf 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,11 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, + module_global_infos, + dummy_global_info, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..53c48b4cc540 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,91 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict, List + +from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo +from tvm.runtime import Object as tvm_Object + + from . 
import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None: + """Specify the global infos of the ir_module frame. + Parameters + ---------- + global_infos: Dict[str, List[GlobalInfo]] + The module global infos. + """ + return _ffi_api.ModuleGlobalInfos(global_infos) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### GlobalInfo ############################### + + +def dummy_global_info() -> DummyGlobalInfo: + """Create a dummy global info expression. + Returns + ------- + res : DummyGlobalInfo + The result dummy global info. + """ + return DummyGlobalInfo() # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/relax/__init__.py b/python/tvm/script/ir_builder/relax/__init__.py new file mode 100644 index 000000000000..f0905acf34e3 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Package tvm.script.ir_builder.relax""" +from . 
import frame +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py new file mode 100644 index 000000000000..6e2098cf88af --- /dev/null +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder.relax""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py new file mode 100644 index 000000000000..97e181fbe4be --- /dev/null +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IR Builder Frame for Relax dialect""" +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.relax.RelaxFrame") +class RelaxFrame(IRBuilderFrame): + """The base ir_builder frame for the relax dialect.""" + + +@_register_object("script.ir_builder.relax.SeqExprFrame") +class SeqExprFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.FunctionFrame") +class FunctionFrame(SeqExprFrame): + """The ir_builder frame for the relax function.""" + + +@_register_object("script.ir_builder.relax.BlockFrame") +class BlockFrame(RelaxFrame): + """The ir_builder frame for relax binding blocks.""" + + +@_register_object("script.ir_builder.relax.IfFrame") +class IfFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.ThenFrame") +class ThenFrame(SeqExprFrame): + ... + + +@_register_object("script.ir_builder.relax.ElseFrame") +class ElseFrame(SeqExprFrame): + ... 
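
The frames registered above are normally driven through `IRBuilder` with-scopes rather than constructed directly; the helpers that populate them (`R.function`, `R.arg`, `R.emit`, ...) follow in `ir.py` below. A rough sketch of how the pieces compose under the relax ir_builder API introduced in this patch; the function name, shape, and operator are illustrative:

.. code-block:: python

    from tvm import relax
    from tvm.script.ir_builder import IRBuilder
    from tvm.script.ir_builder import ir as I
    from tvm.script.ir_builder import relax as R

    with IRBuilder() as ib:               # thread-local builder scope
        with I.ir_module():               # IRModuleFrame
            with R.function():            # FunctionFrame (a SeqExprFrame)
                R.func_name("main")
                x = R.arg("x", relax.TensorStructInfo((4,), "float32"))
                y = R.emit(relax.op.add(x, x))
                R.func_ret_value(y)

    mod = ib.get()                        # the assembled IRModule
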
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py new file mode 100644 index 000000000000..2f8a37a4e1da --- /dev/null +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -0,0 +1,659 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name +"""IRBuilder for Relax dialect""" + +import builtins +import functools +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union, Callable + +import tvm +from tvm import DataType, relax +from tvm.ir import PrimExpr +from ..ir import decl_function +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, ShapeExpr, Var, VarBinding, const +from tvm.relax.utils import gen_call_tir_inputs + + +############################### Operators ############################### +from tvm.relax.op import ( + abs, + acos, + acosh, + asin, + asinh, + atan, + atanh, + add, + argmax, + argmin, + assert_op, + astype, + broadcast_to, + builtin, + call_builtin_with_ctx, + call_tir, + call_dps_packed, + ceil, + clip, + collapse_sum_like, + collapse_sum_to, + concat, + cos, + cosh, + cumsum, + divide, + equal, + ewise_fma, + exp, + expand_dims, + flatten, + floor, + floor_divide, + full, + full_like, + greater, + greater_equal, + image, + invoke_closure, + isfinite, + isinf, + isnan, + layout_transform, + less, + less_equal, + linear, + log, + make_closure, + matmul, + max, + maximum, + mean, + memory, + min, + minimum, + multiply, + negative, + not_equal, + null_value, + ones, + ones_like, + permute_dims, + power, + print, + prod, + repeat, + reshape, + tensor_to_shape, + shape_to_tensor, + round, + shape_of, + std, + strided_slice, + sum, + take, + variance, + sigmoid, + sign, + sin, + sinh, + split, + square, + squeeze, + sqrt, + subtract, + tan, + tanh, + tile, + tril, + triu, + unique, + vm, + where, + wrap_param, + zeros, + zeros_like, + nn, +) +from tvm.relax.op.builtin import stop_lift_params +from tvm.relax.struct_info import StructInfo +from tvm.relax.utils import args_converter +from tvm.runtime import Object as tvm_Object +from tvm.runtime import ObjectGeneric + +from . import _ffi_api, frame + +##################### Python Native Function Alias ###################### + +py_print = builtins.print +py_tuple = tuple +py_str = str + + +############################### Function ################################ + + +def function() -> frame.FunctionFrame: + """Start a function frame. + Returns + ------- + frame: FunctionFrame + The constructed function frame. + """ + return _ffi_api.Function() # type: ignore[attr-defined] # pylint: disable=no-member + + +def arg(name: py_str, struct_info: StructInfo) -> Var: + """Add a parameter to the last function frame. 
+ Parameters + ---------- + name: str + The name of the parameter. + struct_info: StructInfo + The Struct Info of the parameter + + Returns + ------- + var: Var + The created function parameter var. + """ + + return _ffi_api.Arg(name, struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_name(name: py_str) -> None: + """Specify the name of the last function frame. + Parameters + ---------- + name: str + The function name. + """ + return _ffi_api.FuncName(name) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_attr(attrs: Dict[py_str, tvm_Object]) -> None: + """Specify the attrs of the last function frame. + Parameters + ---------- + attrs: Dict[str, Object] + The function attrs. + """ + return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_struct_info(ret_sinfo: StructInfo) -> None: + """Specify the return struct info of the last function frame. + Parameters + ---------- + ret_type: StructInfo + The function return struct info. + """ + return _ffi_api.FuncRetStructInfo(ret_sinfo) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_value(value: Expr) -> None: + """Specify the return value of the last function frame. + Parameters + ---------- + value: Expr + The function return value. + """ + return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################# BindingBlock ############################## + + +def dataflow() -> frame.BlockFrame: + """Start a dataflow binding block frame. + Returns + ------- + frame: frame.BlockFrame + The created ir_builder Block frame. + """ + return _ffi_api.Dataflow() # type: ignore[attr-defined] # pylint: disable=no-member + + +def output(*vars: Tuple[Var]) -> None: + """Expose the dataflow block output variables as global ones. + Parameters + ---------- + vars: Tuple[Var] + The output variables of a dataflow block. + """ + return _ffi_api.DataflowBlockOutput(vars) # type: ignore[attr-defined] # pylint: disable=no-member + + +################################## Ops ################################# + + +@args_converter.auto +def call_packed( + func: py_str, + *args: Expr, + sinfo_args: Union[StructInfo, List[StructInfo]], + **kwargs: Any, +) -> Call: + """Create a relax Call, which calls a packed function. + Parameters + ---------- + func: str + The name of extern function. + *args : Expr + The arguments. + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments. + kwargs: Expr + The keyword arguments. 
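
A short sketch of `call_packed` as it surfaces in TVMScript through this builder API; the packed function name `"my_packed_func"` is a placeholder and would need to be registered before the module is executed:

.. code-block:: python

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class PackedCall:
        @R.function
        def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
            # sinfo_args supplies the structure info the packed call is assumed to return.
            y = R.call_packed("my_packed_func", x, sinfo_args=R.Tensor((4,), "float32"))
            return y
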
+ + Returns + ------- + call: Call + The created Relax Call + """ + op = ExternFunc(func) + if sinfo_args is None: + raise ValueError("R.call_packed is required to have type_args") + if isinstance(sinfo_args, py_tuple): # type: ignore + sinfo_args = list(sinfo_args) + elif not isinstance(sinfo_args, list): + sinfo_args = [sinfo_args] + for i, sinfo_arg in enumerate(sinfo_args): + if callable(sinfo_arg): + sinfo_arg = sinfo_arg() + # Convert possible StructInfoProxy to StructInfo + if isinstance(sinfo_arg, ObjectGeneric): + sinfo_arg = sinfo_arg.asobject() + sinfo_args[i] = sinfo_arg + + is_default = False + if "attrs_type_key" in kwargs: + attrs_type_key = kwargs["attrs_type_key"] + kwargs.pop("attrs_type_key") + else: + attrs_type_key = "DictAttrs" + is_default = True + attrs = None + if kwargs or not is_default: + attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) + + return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) + + +def _sinfo_arg_wrapper(func): + """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" + + def _convert_tensor_type(args): + if isinstance(args, (list, py_tuple)): # type: ignore + new_args = [_convert_tensor_type(x) for x in args] + return type(args)(new_args) + if isinstance(args, dict): + return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} + if inspect.isfunction(args): + args = args() + if isinstance(args, ObjectGeneric): + args = args.asobject() + return args + + @functools.wraps(func) + def wrapped(*args, **kwargs): + return func(*_convert_tensor_type(args), **_convert_tensor_type(kwargs)) + + return wrapped # type: ignore + + +invoke_closure = _sinfo_arg_wrapper(invoke_closure) # pylint: disable=invalid-name + +call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name + + +############################### Emits ############################### + + +def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: Expr + The right side value of the bindings to be emitted. + + annotate_struct_info: Optional[StructInfo] + The optional struct info annotation for the emitted value. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.Emit(value, annotate_struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the following keyword args are reserved: + + - 'primfunc_name_hint' for passing name hint to the PrimFunc + that gets generated. + - 'primfunc_attrs' is reserved for passing func attributes to + be added to the PrimFunc that gets created. + + Returns + ------- + call : Call + A newly created call that calls into a tir function. 
+ """ + primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) + tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + if not primfunc_name_hint: + primfunc_name_hint = func.__name__ + gvar = decl_function(primfunc_name_hint, tir_func) # type: ignore + return call_tir(gvar, call_args, out_sinfo, tir_vars) + + +def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: + """Emit a match_cast binding to the last binding block frame. + Parameters + ---------- + value: Expr + The value of the MatchCast to be emitted. + struct_info: StructInfo + The struct_info of the MatchCast to be emitted. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore + + +def emit_var_binding(value: VarBinding) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: VarBinding + The binding to be emitted. + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitVarBinding(value) # type: ignore + + +############################# If Then Else ############################# + + +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if frame. + Parameters + ---------- + condition : Expr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then frame. + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else frame. + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### R.tuple ################################ + + +def tuple(*fields: Expr) -> Expr: + """Create a tuple expression. + Parameters + ---------- + *fields : Expr + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + if len(fields) == 0: + fields = py_tuple() + + return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### R.shape ################################ + + +def shape(value: List[PrimExpr]) -> Expr: + """Create a ShapeExpr. + Parameters + ---------- + value : List[PrimExpr] + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore + + +############################### PrimValue ############################## + + +def prim_value(value: PrimExpr) -> Expr: + """Create a prim value expression. + Parameters + ---------- + value : PrimExpr + The value of the prim value. + Returns + ------- + res : Expr + The result prim value. + """ + return relax.PrimValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def str(value: py_str) -> Expr: + """Create a string imm expression. + Parameters + ---------- + value : str + The value of the str. + Returns + ------- + res : Expr + The result str. 
+ """ + return relax.StringImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def dtype(value: Union[py_str, DataType]) -> Expr: + """Create a dtype imm expression. + Parameters + ---------- + value : dtype + The value of the dtype. + Returns + ------- + res : Expr + The result dtype. + """ + return relax.DataTypeImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### Importer ############################### + +__all__ = [ + "Else", + "If", + "Then", + "TupleGetItem", + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", + "add", + "arg", + "argmax", + "argmin", + "assert_op", + "astype", + "broadcast_to", + "builtin", + "call_packed", + "call_tir", + "call_dps_packed", + "call_builtin_with_ctx", + "ceil", + "clip", + "collapse_sum_like", + "collapse_sum_to", + "concat", + "cos", + "cosh", + "const", + "cumsum", + "dataflow", + "divide", + "dtype", + "emit", + "emit_te", + "emit_var_binding", + "emit_match_cast", + "equal", + "ewise_fma", + "exp", + "expand_dims", + "flatten", + "floor", + "floor_divide", + "full", + "full_like", + "func_attr", + "func_name", + "func_ret_struct_info", + "func_ret_value", + "function", + "greater", + "greater_equal", + "image", + "invoke_closure", + "isfinite", + "isinf", + "isnan", + "layout_transform", + "less", + "less_equal", + "linear", + "log", + "make_closure", + "matmul", + "max", + "maximum", + "mean", + "memory", + "min", + "minimum", + "multiply", + "negative", + "not_equal", + "null_value", + "ones", + "ones_like", + "output", + "permute_dims", + "power", + "prim_value", + "print", + "prod", + "repeat", + "reshape", + "tensor_to_shape", + "shape_to_tensor", + "round", + "shape", + "shape_of", + "ShapeExpr", + "std", + "str", + "strided_slice", + "sum", + "sigmoid", + "sign", + "sin", + "sinh", + "split", + "square", + "squeeze", + "sqrt", + "stop_lift_params", + "str", + "strided_slice", + "subtract", + "take", + "tan", + "tanh", + "tile", + "tril", + "triu", + "tuple", + "unique", + "variance", + "vm", + "where", + "wrap_param", + "zeros", + "zeros_like", + "nn", +] diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c3ced1e0338b..5f324393090a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1810,6 +1810,10 @@ def wrapped(*args, **kwargs): TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +anylist_getitem = _op_wrapper(_tir_op.anylist_getitem) +anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) +anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) +anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) def _dtype_forward(func): @@ -2089,6 +2093,10 @@ def wrapped(*args, **kwargs): "start_profile_intrinsic", "end_profile_intrinsic", "meta_var", + "anylist_getitem", + "anylist_resetitem", + "anylist_setitem_call_packed", + "anylist_setitem_call_cpacked", "llvm_lookup_intrinsic_id", "type_annotation", "broadcast", diff --git a/python/tvm/script/parser/__init__.py b/python/tvm/script/parser/__init__.py index 5161a2601c49..ba7f085c08a4 100644 --- a/python/tvm/script/parser/__init__.py +++ b/python/tvm/script/parser/__init__.py @@ -13,9 +13,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations -# under the Licens. +# under the License. """The parser""" -from . import _core, ir, tir +from . import _core, ir from ._core import parse from .ir import ir_module -from .tir import prim_func diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..d8a11f5b462a 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -43,6 +43,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) if extra_vars is None: import tvm # pylint: disable=import-outside-toplevel from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import relax # pylint: disable=import-outside-toplevel from tvm.script.parser import tir # pylint: disable=import-outside-toplevel extra_vars = { @@ -51,6 +52,8 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "relax": relax, + "R": relax, } source = Source(program) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..9dbe7a8e3479 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -19,10 +19,12 @@ from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Set, Union + import numpy as np -from tvm._ffi.base import TVMError +from tvm._ffi.base import TVMError from tvm.error import DiagnosticError +from tvm.ir import GlobalVar from . import dispatch, doc from .diagnostics import Diagnostics, Source @@ -148,7 +150,7 @@ def add(self, var: str, value: Any, allow_shadowing: bool = False): The value of variable. allow_shadowing : bool - The options of whether variable shadowing allwed for this variable. + The options of whether variable shadowing allowed for this variable. 
""" # Skip if the key and value are equal to those in the var_table if self.name2value[var] and isinstance(self.name2value[var][-1], type(value)): @@ -259,13 +261,24 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. Parameters ---------- token : str - The dispathing token. + The dispatching token. Returns ------- @@ -273,10 +286,17 @@ def with_dispatch_token(self, token: str): The context with new dispatching token. """ + self.dispatch_tokens.append(token) + enter_func = dispatch.get(token=token, type_name="enter_token", default=lambda *args: None) + context = enter_func(self) + def pop_token(): + exit_func = dispatch.get( + token=token, type_name="exit_token", default=lambda *args: None + ) + exit_func(self, context) self.dispatch_tokens.pop() - self.dispatch_tokens.append(token) return _deferred(pop_token) def eval_expr( @@ -357,12 +377,12 @@ def eval_assign( The value binding method when assigning the values to variables. allow_shadowing : bool - The options of whether variable shadowing allwed for assignment. + The options of whether variable shadowing allowed for assignment. Returns ------- res : Dict[str, Any] - The dirctionary of assignment result. + The dictionary of assignment result. """ if self._duplicate_lhs_check(target) is True: self.report_error(target, "Duplicate vars assigned.") @@ -388,6 +408,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -457,30 +479,26 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. 
""" - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + _dispatch(self, "pre_visit_local_function")(self, node) _dispatch_wrapper(func)(self, node) + _dispatch(self, "post_visit_local_function")(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> GlobalVar: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + return _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. @@ -596,7 +614,7 @@ def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name Parameters ---------- node : doc.Expr - The doc AST exprssion node. + The doc AST expression node. Returns ------- diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 6a693df12f89..3edae3f25a33 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -22,6 +22,30 @@ from .diagnostics import findsource +def get_func_nonlocals(func): + """A modified version of `inspect.getclosurevars`""" + + if inspect.ismethod(func): + func = func.__func__ + + if not inspect.isfunction(func): + raise TypeError("{!r} is not a Python function".format(func)) + + code = func.__code__ + # Nonlocal references are named in co_freevars and resolved + # by looking them up in __closure__ by positional index + nonlocal_vars = {} + if func.__closure__ is not None: + for var, cell in zip(code.co_freevars, func.__closure__): + try: + nonlocal_vars[var] = cell.cell_contents + except ValueError as err: + # cell_contents may raise ValueError if the cell is empty. + if "empty" not in str(err): + raise + return nonlocal_vars + + def inspect_function_capture(func: Callable) -> Dict[str, Any]: """Capture function non-locals and global variables. @@ -37,7 +61,7 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]: """ captured = { **func.__globals__, # type: ignore - **inspect.getclosurevars(func).nonlocals, + **get_func_nonlocals(func), } return captured diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..f8c9d4f0afc9 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..e11fa431627b 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -14,12 +14,22 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """The base parser for ir module""" from ...ir_builder import ir as I from .._core import Parser, dispatch, doc +class ModuleWithGlobalVars: + """A Module that can add global vars during parsing, to support `Module.function` syntax.""" + + def __getattr__(self, attr): + # Customize the error message. + # NOTE: `__getattr__` is only called when the attribute access fails with an AttributeError + raise AttributeError(f"Cannot find the function `{attr}` in the current IRModule") + + @dispatch.register(token="ir", type_name="ClassDef") def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: """The class definition visiting method for ir module. @@ -32,10 +42,32 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + # Step 0. Add the class name to the var table + fake_module = ModuleWithGlobalVars() + self.var_table.add(node.name, fake_module) + + # Step 1. Visit non-function stmts, including but not limited to + # 1. `I.module_attrs` + # 2. `I.module_global_infos` + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) + + # Step 2. Visit function stmts to declare the global vars + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + global_var = self.visit_tvm_declare_function(stmt) + fake_module.__setattr__(stmt.name, global_var) + + # Step 3. Visit and parse the functions with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -53,7 +85,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -64,6 +96,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") @@ -75,3 +108,13 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: self.eval_assign( target=lhs, source=rhs, bind_value=lambda _a, _b, _c, value: value, allow_shadowing=True ) + + +@dispatch.register(token="default", type_name="pre_visit_local_function") +def pre_visit_local_function(self: Parser, node: doc.Expr) -> None: + pass + + +@dispatch.register(token="default", type_name="post_visit_local_function") +def post_visit_local_function(self: Parser, node: doc.Expr) -> None: + pass diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py new file mode 100644 index 000000000000..1715526086b7 --- /dev/null +++ b/python/tvm/script/parser/relax/__init__.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Initial impl of relax parser for sugars""" + +from typing import TYPE_CHECKING + +from ...ir_builder.relax import * # pylint: disable=redefined-builtin +from ...ir_builder.relax import ir as _relax +from . import parser as _parser +from .entry import Callable, Object, Prim, Shape, Tensor, Tuple, match_cast + +if TYPE_CHECKING: + # pylint: disable=invalid-name + # Define prim_func and make it type check as static method + # so most tvmscript won't trigger pylint error here. + function = staticmethod +else: + from .entry import function + +__all__ = _relax.__all__ + [ + "Callable", + "Object", + "Prim", + "Shape", + "Tensor", + "Tuple", + "function", + "match_cast", +] diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py new file mode 100644 index 000000000000..acb490a813b8 --- /dev/null +++ b/python/tvm/script/parser/relax/entry.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
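For orientation, a minimal editorial sketch (not part of the diff) of how the `R.function` decorator and the struct-info annotations exported above are meant to be written in user code. The import paths and the `R.add` operator follow the usual TVMScript convention and are assumptions here, not something this hunk defines:

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        # The parameter/return annotations are parsed into TensorStructInfo
        # through the proxy classes defined in entry.py below.
        y = R.add(x, x)  # assumed relax op binding, for illustration only
        return y

The `@I.ir_module` class is handled by the ir parser (Steps 0-3 above), while each `@R.function` body is dispatched to the relax parser added in this patch.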
+# pylint: disable=missing-docstring, invalid-name +import inspect +from typing import Any +from typing import Callable as _Callable +from typing import Dict, List, Optional, Set, TypeVar, Union + +from tvm.relax import ( + Expr, + ShapeExpr, + FuncStructInfo, + Function, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + StructInfo, + TensorStructInfo, + TupleStructInfo, +) +from tvm.relax.expr import Var +from tvm.runtime import ObjectGeneric +from tvm.tir import PrimExpr + +from .._core import parse, utils + +FType = TypeVar("FType", bound=_Callable) + +############################## R.function ############################## + + +def function(f: FType) -> Union[Function, FType]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if utils.is_defined_in_class(inspect.stack(), f): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(function, "dispatch_token", "relax") + + +############################# Struct Info ############################## + + +class StructInfoProxy(ObjectGeneric): + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo: + raise NotImplementedError() + + def get_symbolic_vars(self) -> Set[str]: + return {} + + def asobject(self): + return self.as_struct_info(None) + + +############################### R.Tensor ############################### + + +def _eval_shape(expr: Union[str, PrimExpr], dict_globals: Optional[Dict[str, Any]]) -> PrimExpr: + if isinstance(expr, str): + code = compile(expr, "", "eval") + return eval(code, dict_globals or {}) # pylint: disable=eval-used + else: + return expr + + +class TensorProxy(StructInfoProxy): + shape: Optional[List[Union[str, PrimExpr]]] + dtype: str + ndim: int + + def __init__( + self, + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, + dtype: Optional[str] = None, + ndim: int = -1, + ) -> None: + if isinstance(shape, Expr): + if not isinstance(shape, (ShapeExpr, Var)): + raise ValueError( + "When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr " + f"value. But got: {shape} with type: {type(shape)}" + ) + if isinstance(shape, Var) and not isinstance(shape.struct_info, ShapeStructInfo): + raise ValueError( + "When the shape is a Var, it must have shape struct_info. But got " + f"{shape} with struct_info: {shape.struct_info}" + ) + self.shape = shape + self.dtype = dtype + self.ndim = ndim + + def get_symbolic_vars(self) -> Set[str]: + if self.shape is None or isinstance(self.shape, Expr): + return {} + else: + return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: + if self.shape is None: + return TensorStructInfo(None, self.dtype, self.ndim) + elif isinstance(self.shape, (ShapeExpr, Var)): + return TensorStructInfo(self.shape, self.dtype, self.ndim) + else: + if dict_globals is None and any([isinstance(s, str) for s in self.shape]): + raise ValueError( + "String-defined shape expr is only allowed when parsing function parameters " + "and return annotations for TVMScript." 
+ ) + shape = [_eval_shape(s, dict_globals) for s in self.shape] + return TensorStructInfo(shape, self.dtype, self.ndim) + + +def Tensor( + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, + dtype: Optional[str] = None, + ndim: int = -1, +) -> TensorProxy: + # scalar tensor case + if shape is not None and not isinstance(shape, Var) and len(shape) == 0: + shape = [] + if isinstance(shape, str) and dtype is None: + dtype = shape + shape = None + + if shape is not None and not isinstance(shape, (tuple, list)) and not isinstance(shape, Expr): + raise ValueError(f"shape must be a list/tuple or an Expr, but got: {shape}") + return TensorProxy(shape, dtype, ndim) + + +############################## R.Callable ############################## + + +class CallableProxy(StructInfoProxy): + params: List[StructInfoProxy] + ret: StructInfoProxy + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + Parameters + ---------- + params : List[StructInfoProxy] + The argument StructInfoProxy + + ret : StructInfoProxy + The return StructInfoProxy. + + """ + + def __init__( + self, + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, + ) -> None: + if not isinstance(params, (list, tuple)): + params = [params] + # convert `R.Tensor` to `R.Tensor()` + self.params = [param() if callable(param) else param for param in params] + self.ret = ret() if callable(ret) else ret + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[p.get_symbolic_vars() for p in self.params]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo: + params = [param.as_struct_info(dict_globals) for param in self.params] + ret = self.ret.as_struct_info(dict_globals) + return FuncStructInfo(params, ret) + + +def Callable( + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, +) -> CallableProxy: + return CallableProxy(params, ret) + + +############################### R.Tuple ################################ + + +class TupleProxy(StructInfoProxy): + fields: List[StructInfoProxy] + """The type of tuple values. + + Parameters + ---------- + fields : List[StructInfoProxy] + The fields in the tuple + """ + + def __init__( + self, + *fields: List[StructInfoProxy], + ) -> None: + if len(fields) == 1 and isinstance(fields[0], (tuple, list)): + fields = fields[0] + # convert `R.Tensor` to `R.Tensor()` + self.fields = [field() if callable(field) else field for field in fields] + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[f.get_symbolic_vars() for f in self.fields]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo: + fields = [field.as_struct_info(dict_globals) for field in self.fields] + return TupleStructInfo(fields) + + +def Tuple(*fields: List[StructInfoProxy]) -> TupleProxy: + return TupleProxy(*fields) + + +############################### R.Shape ################################ + + +class ShapeProxy(StructInfoProxy): + values: Optional[List[PrimExpr]] + ndim: int + """The type of shape values. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. 
+ """ + + def __init__( + self, + values: Optional[List[PrimExpr]] = None, + ndim: int = -1, + ) -> None: + self.values = values + self.ndim = ndim + + def get_symbolic_vars(self) -> Set[str]: + if self.values is None: + return {} + else: + return {v for v in self.values if isinstance(v, str) and v.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + values = [_eval_shape(v, dict_globals) for v in self.values] if self.values else None + return ShapeStructInfo(values, self.ndim) + + +def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy: + return ShapeProxy(values, ndim) + + +############################### R.Object ################################ + + +class ObjectProxy(StructInfoProxy): + """The proxy fo ObjectStructInfo. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __init__(self) -> None: + pass + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return ObjectStructInfo() + + +def Object() -> ObjectProxy: + return ObjectProxy() + + +################################ R.Prim ################################ + + +class PrimProxy(StructInfoProxy): + dtype: str + """The type of shape values. + + Parameters + ---------- + dtype : str + The data type. + """ + + def __init__(self, dtype: str) -> None: + self.dtype = dtype + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return PrimStructInfo(self.dtype) + + +def Prim(dtype: str) -> PrimProxy: + return PrimProxy(dtype) + + +############################ R.match_cast ############################# +class MatchCastPair: + value: Expr + struct_info: StructInfo + + def __init__(self, value: Expr, struct_info: StructInfo) -> None: + self.value = value + self.struct_info = struct_info + + +def match_cast(value: Expr, struct_info: StructInfo): + if value is None: + raise ValueError("value of match_cast cannot be None") + if struct_info is None: + raise ValueError("struct_info of match_cast cannot be None") + return MatchCastPair(value, struct_info) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py new file mode 100644 index 000000000000..06fc51b7a607 --- /dev/null +++ b/python/tvm/script/parser/relax/parser.py @@ -0,0 +1,359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring, unused-argument + +import functools +import numbers +from typing import Any, Dict, Optional + +from tvm import relax, tir +from tvm.ir import GlobalVar, structural_equal +from tvm.relax import Expr, StructInfo +from tvm.relax.utils import convert_to_expr +from tvm.script.ir_builder.relax.frame import BlockFrame + +from ...ir_builder import ir as I +from ...ir_builder import relax as R +from ...ir_builder.base import IRBuilder +from .._core import Parser, dispatch, doc +from .entry import MatchCastPair, StructInfoProxy, TupleProxy + + +def bind_assign_value( + self: Parser, + node: doc.expr, + var_name: str, + value: Any, + anno_sinfo: Optional[StructInfo] = None, +) -> Any: + var_table = self.var_table.get() + + if isinstance(value, tir.Var): + if value.name and var_name != value.name: + self.report_error( + node, + "Cannot define TIR variables with different names. The LHS of binding should " + "has the same name provided in RHS.", + ) + if var_name in var_table: + prev_value = var_table[var_name] + if not isinstance(prev_value, tir.Var): + self.report_error( + node, + "Cannot redefine a non-TIR-variable object to a TIR variable. Please " + "define the TIR variable with another name.", + ) + if prev_value.dtype != value.dtype: + self.report_error( + node, + "Expected the same dtype for TIR vars " + f"but got {value.dtype} vs {prev_value.dtype}", + ) + return prev_value + IRBuilder.name(var_name, value) + return value + + if isinstance(value, tuple): + value = convert_to_expr(value) + if isinstance(value, numbers.Number): + value = R.const(value) + + if isinstance(value, relax.Expr): + var = R.emit(value, anno_sinfo) + elif isinstance(value, MatchCastPair): + if anno_sinfo is not None and not structural_equal(anno_sinfo, value.struct_info): + self.report_error( + node, "Cannot specify inconsistent annotation for a match cast pair. 
" + ) + var = R.emit_match_cast(value.value, value.struct_info) + else: + return value + # raise TypeError(f"Unsupported type {type(value)} in assignment") + + IRBuilder.name(var_name, var) + return var + + +def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: + try: + annotation = self.eval_expr(node) + if annotation is None: + return TupleProxy([]) + if callable(annotation): + annotation = annotation() + if isinstance(annotation, StructInfoProxy): + return annotation + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: + var_table = self.var_table.get() if eval_str else None + try: + return eval_struct_info_proxy(self, node).as_struct_info(var_table) + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def is_called(node: Any, func_name: str) -> bool: + # Check if it calls into a func + if isinstance(node, doc.Call): + # Recursive call was found + if isinstance(node.func, doc.Name) and node.func.id == func_name: + return True + elif isinstance(node, (list, tuple)): + for stmt in node: + if is_called(stmt, func_name): + return True + elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)): + return is_called(node.value, func_name) + elif isinstance(node, doc.With): + return is_called(node.body, func_name) + elif isinstance(node, doc.If): + smts = [] + if node.body is not None: + smts = smts + list(node.body) + if node.orelse is not None: + smts = smts + list(node.orelse) + return is_called(smts, func_name) + return False + + +def is_recursive(node: doc.FunctionDef) -> bool: + # Check if it is a recursive function + for stmt in node.body: + if is_called(stmt, node.name): + return True + return False + + +def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: + # Collect symbolic vars from parameters + symbolic_vars = set() + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) + symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + + # Define symbolic vars to the current var_table frame + for var_name in symbolic_vars: + self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + # reserve a var for local function + func_val = self.var_table.get().get(node.name) + if not func_val and is_recursive(node): + collect_symbolic_var_from_params(self, node) + if node.returns is None: + ret_sinfo = relax.TupleStructInfo([]) + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + # created a var for the local function, the same var could be used for recursive call + local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) + self.var_table.add(node.name, local_func_var) + + with self.var_table.with_frame(): + with self.with_dispatch_token("relax"): + with R.function(): + R.func_name(node.name) + 
collect_symbolic_var_from_params(self, node) + + if node.returns is not None: + ann_sinfo = eval_struct_info(self, node.returns, eval_str=True) + R.func_ret_struct_info(ann_sinfo) + + self.visit(node.args) + + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + if not stmt.decorator_list: + self.report_error(stmt, "Function must be decorated") + dec = self.eval_expr(stmt.decorator_list[-1]) + # inline prim_func was found + if dec.dispatch_token == "tir": + self.report_error(stmt, "inline prim_func is disallowed in Relax IR") + + self.visit_body(node.body) + + +@dispatch.register(token="relax", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: + with self.var_table.with_frame(): + collect_symbolic_var_from_params(self, node) + + if node.returns is None: + # Use ObjectStructInfo as unknown return type + # NOTE: Cannot use VoidStructInfo here because the return type can be refined later. + ret_sinfo = relax.ObjectStructInfo() + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params.append(relax.Var(arg.arg, param_sinfo)) + + func_signature = relax.Function.create_empty(params, ret_sinfo) + return I.decl_function(node.name, func_signature) + + +@dispatch.register(token="relax", type_name="pre_visit_local_function") +def pre_visit_local_function(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder() + ir_builder.__enter__() + + +@dispatch.register(token="relax", type_name="post_visit_local_function") +def post_visit_local_function(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder.current() + result = ir_builder.get() + ir_builder.__exit__(None, None, None) + # reuse var if it is reserved + reserved_var = self.var_table.get().get(node.name) + if reserved_var: + var = R.emit_var_binding(relax.VarBinding(reserved_var, result)) + else: + var = R.emit(result) + IRBuilder.name(node.name, var) + self.var_table.add(node.name, var, allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + value = self.eval_expr(node.value) + if value is not None: + self.report_error(node, f"Unsupported Expr stmt type {value}.") + + +@dispatch.register(token="relax", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + param = R.arg(arg.arg, param_sinfo) + + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="relax", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr) -> StructInfo: + return eval_struct_info(self, node, eval_str=False) + + +@dispatch.register(token="relax", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + # Currently only `with R.dataflow()` is supported + if len(node.items) != 1: + self.report_error(node, "Only one item is allowed.") + item = node.items[0] + if item.optional_vars is not None: + self.report_error( + item.context_expr, + "Relax syntax doesn't allow binding expressions in `with` to variables", + ) + frame = 
self.eval_expr(item.context_expr) + with self.var_table.with_frame(): + with frame: + self.visit(node.body) + if isinstance(frame, BlockFrame) and frame.is_dataflow: + output_vars = frame.output_vars + for var in output_vars: + self.var_table.add(var.name_hint, var, allow_shadowing=True) + + +@dispatch.register(token="relax", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_assign_value, + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + anno_sinfo = self.visit_tvm_annotation(node.annotation) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=functools.partial(bind_assign_value, anno_sinfo=anno_sinfo), + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="Return") +def visit_return(self: Parser, node: doc.Assign) -> None: + value = self.eval_expr(node.value) + value = convert_to_expr(value) + R.func_ret_value(value) + + +@dispatch.register(token="relax", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + if node.orelse is None: + raise ValueError("Else statements are required for relax dialect.") + with R.If(self.eval_expr(node.test)) as if_frame: + with self.var_table.with_frame(): + with R.Then(): + self.visit_body(node.body) + with self.var_table.with_frame(): + with R.Else(): + self.visit_body(node.orelse) + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) + + +@dispatch.register(token="relax", type_name="enter_token") +def enter_token(self: Parser) -> Dict[str, Any]: + def relax_call(self, *args) -> Expr: + if all(isinstance(x, Expr) for x in args): + return relax.Call(self, args) + arg_types = [type(x) for x in args] + raise RuntimeError( + "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) + ) + + context = {"GlobalVar.__call__": GlobalVar.__call__} + GlobalVar.__call__ = relax_call + return context + + +@dispatch.register(token="relax", type_name="exit_token") +def exit_token(self: Parser, context: Dict[str, Any]) -> None: + assert "GlobalVar.__call__" in context + GlobalVar.__call__ = context.get("GlobalVar.__call__") diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a3..e44b6b521b27 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -32,4 +32,4 @@ else: from .entry import prim_func -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = 
dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..48502df2a64d 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -21,9 +21,10 @@ from typing import Any import tvm -from tvm.ir import PrimType +from tvm.ir import GlobalVar, PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,27 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + return I.decl_function(node.name, func_signature) diff --git a/python/tvm/script/relax.py b/python/tvm/script/relax.py new file mode 100644 index 000000000000..2301463059e3 --- /dev/null +++ b/python/tvm/script/relax.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script APIs of TVM Python Package for Relax""" +from .parser.relax import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import diff --git a/python/tvm/script/tir.py b/python/tvm/script/tir.py new file mode 100644 index 000000000000..49f3ecd42c50 --- /dev/null +++ b/python/tvm/script/tir.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script APIs of TVM Python Package for TIR""" +from .parser.tir import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 0907ea2ebf85..40fac0f92f6d 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -41,6 +41,7 @@ from .operation import placeholder, compute, scan, extern, var, size_var, const from .operation import thread_axis, reduce_axis from .operation import create_prim_func +from .operation import create_relax_prim_func from .operation import extern_primfunc from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 59bc76f5041e..cfe5e073bae2 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -571,12 +571,64 @@ def create_prim_func( ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None ) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression + Parameters + ---------- + ops : List[Tensor] + The source expression. + Example + ------- + We define a matmul kernel using following code: + .. code-block:: python + import tvm + from tvm import te + from tvm.te import create_prim_func + import tvm.script + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + k = te.reduce_axis((0, 128), "k") + C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") + func = create_prim_func([A, B, C]) + print(func.script()) + If we want to use TensorIR schedule to do transformations on such kernel, + we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc. + The generated function looks like: + .. code-block:: python + @T.prim_func + def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + for i, j, k in T.grip(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(ops, (list, tuple, Array)): + ops = [ops] + return _ffi_api.CreatePrimFunc(ops, index_dtype_override) + + +def create_relax_prim_func( + ops: List[_tensor.Tensor], + tir_var_list: List[tvm.tir.Var] = None, + index_dtype_override: Optional[str] = None, +) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from tensor expression Parameters ---------- ops : List[Tensor] The source expression. 
+ tir_var_list: List[Var] + TIR variables to add as parameters to generated PrimFunc + Example ------- We define a matmul kernel using following code: @@ -621,4 +673,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops, index_dtype_override) + return _ffi_api.CreateRelaxPrimFunc(ops, tir_var_list, index_dtype_override) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0fe460c085d7..d24fe8e693e6 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,6 +3037,74 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) +def anylist_getitem(list_handle, index): + """Returns an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.anylist_getitem", list_handle, index) + + +def anylist_resetitem(list_handle, index): + """Reset an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int", "tir.anylist_resetitem", list_handle, index) + + +def anylist_setitem_call_packed(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args + ) + + +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args + ) + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9fa0e3bc181f..5e9457fb5391 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -69,6 +69,7 @@ def prim_func_pass( opt_level: int = None, name: Optional[str] = None, required: Optional[List[str]] = None, + traceable=False, ) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. 
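To make the new `traceable` flag concrete, a hedged sketch of a pass declared through this decorator (editorial, not part of the diff; the pass body is a deliberate no-op placeholder):

import tvm
from tvm.tir.transform import prim_func_pass

@prim_func_pass(opt_level=0, traceable=False)
def simple_pass(func, mod, ctx):
    # A real pass would return a rewritten tir.PrimFunc here.
    return func

The resulting PrimFuncPass carries traceable=False in its PassInfo, which is what the Sequential tracing logic in src/ir/transform.cc (further down in this diff) consults when deciding whether to wrap the pass in a default knob.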
@@ -147,7 +148,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not callable(pass_arg): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1df2ac76b5b4..612be133de67 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -578,6 +578,21 @@ def NarrowDataType(target_bits: int): return _ffi_api.NarrowDataType(target_bits) # type: ignore +def ForceNarrowIndexToInt32(): + """Force narrow down indexing expressions and integer buffers to int32 dtype. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + Note + ---- + This pass should not be used in default cases. + """ + return _ffi_api.ForceNarrowIndexToInt32() # type: ignore + + def VerifyMemory(): """Verify if func contains illegal host side direct memory access. @@ -1018,3 +1033,21 @@ def InstallDebugSpans(): The result pass """ return _ffi_api.InstallDebugSpans() # type: ignore + + +def DefaultGPUSchedule(): + """The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, + allowing their build and execution on GPU devices. It examines all the blocks within the + PrimFunc and conducts loop fusion, splitting, and reordering operation based on the loop + extent and target information, such as the maximum thread block number and maximum thread + per block. + + The primary objective of this pass is not to optimize performance, but rather to generate + a valid GPU kernel for unscheduled or symbolic shape PrimFuncs. The pass is currently only + working for CUDA targets. + + Returns + ------- + ret: tvm.transform.Pass + """ + return _ffi_api.DefaultGPUSchedule() # type: ignore diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py index e60314b82757..1a707cef7ee6 100644 --- a/python/tvm/topi/hexagon/qnn/nn.py +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -38,24 +38,49 @@ def clip_cast(val, dtype): return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) +def is_relax_constant(expr): + return hasattr(expr.op, "value") and isinstance(expr.op.value, tvm.relax.expr.Constant) + + # Return True if given expression is scalar constant value. def is_scalar(expr): + """ + Return True if given expression is scalar constant value. 
+ """ if isinstance(expr, te.Tensor): - return expr.ndim == 0 and (isinstance(expr.op.body[0], (tvm.tir.FloatImm, tvm.tir.IntImm))) + if is_relax_constant(expr): + shape = expr.op.value.data.shape + dtype = expr.op.value.data.dtype + return len(shape) == 0 and ("float" in dtype or "int" in dtype) + else: + return expr.ndim == 0 and ( + isinstance(expr.op.body[0], (tvm.tir.FloatImm, tvm.tir.IntImm)) + ) return isinstance(expr, (tvm.tir.FloatImm, tvm.tir.IntImm)) +def get_relax_scalar_const_value(expr): + assert len(expr.op.value.data.shape) == 0 + return expr.op.value.data.numpy()[()] + + def get_const_int_value(expr): if isinstance(expr, te.Tensor): - assert isinstance(expr.op.body[0], tvm.tir.IntImm) - return expr.op.body[0].value + if is_relax_constant(expr): + return get_relax_scalar_const_value(expr) + else: + assert isinstance(expr.op.body[0], tvm.tir.IntImm) + return expr.op.body[0].value return get_const_int(expr) def get_const_float_value(expr): if isinstance(expr, te.Tensor): - assert isinstance(expr.op.body[0], tvm.tir.FloatImm) - return expr.op.body[0].value + if is_relax_constant(expr): + return get_relax_scalar_const_value(expr) + else: + assert isinstance(expr.op.body[0], tvm.tir.FloatImm) + return expr.op.body[0].value return get_const_float(expr) @@ -224,7 +249,7 @@ def _compute(*indices): # Add output zero point + clip + cast: return saturate(te.add(mul, output_zp), out_dtype).astype(out_dtype) - return te.compute(data.shape, _compute, name="requantize") + return te.compute(data.shape, _compute, name="requantize_scalar") else: @@ -285,8 +310,8 @@ def _compute_const(x: te.Tensor, iscale, input_zp): def _compute_tensor(x: te.Tensor, input_scale, input_zp): if is_scalar(input_scale) and is_scalar(output_scale): - iscale = input_scale.op.body[0].value - oscale = output_scale.op.body[0].value + iscale = get_const_float_value(input_scale) + oscale = get_const_float_value(output_scale) scale = iscale / oscale scale_fixed_point, rsh = get_fixed_point_value(scale, "int16") return te.compute( @@ -406,7 +431,9 @@ def _compute_tensor(tensor, zero_point): if is_scalar(lhs_scale) and is_scalar(rhs_scale): assert isinstance(lhs_scale, te.Tensor) assert isinstance(rhs_scale, te.Tensor) - iscale = lhs_scale.op.body[0] * rhs_scale.op.body[0] + iscale = get_const_float_value(lhs_scale.op.body[0]) * get_const_float_value( + rhs_scale.op.body[0] + ) else: iscale = lhs_scale * rhs_scale diff --git a/python/tvm/topi/nn/group_norm.py b/python/tvm/topi/nn/group_norm.py index c6358b8bc6ff..ea9d5da0770c 100644 --- a/python/tvm/topi/nn/group_norm.py +++ b/python/tvm/topi/nn/group_norm.py @@ -20,6 +20,8 @@ def group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5): """Group normalization operator. + It accepts fp16 and fp32 as input data type. It will cast the input to fp32 + to perform the computation. The output will have the same data type as input. Parameters ---------- diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py index 3bdeaaac61a5..7363f99c4950 100644 --- a/python/tvm/topi/nn/layer_norm.py +++ b/python/tvm/topi/nn/layer_norm.py @@ -20,6 +20,8 @@ def layer_norm(data, gamma, beta, axis, epsilon=1e-5): """Layer normalization operator. + It accepts fp16 and fp32 as input data type. It will cast the input to fp32 + to perform the computation. The output will have the same data type as input. 
Parameters ---------- diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index 45d07af577a3..5045cb817457 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False): ret : tvm.te.Tensor """ return cpp.prod(data, axis, keepdims) + + +def collapse_sum(data, target_shape): + """Return a summation of data to the given shape. + + collapse_sum is intended as the backward operator of topi broadcast operators in the automatic + differentiation process. + + We expect that data is the result of broadcasting some tensor of target_shape in some + broadcast operation. Thus target_shape and data.shape must follow broadcast rules. + + During computation, the axes of data.shape and target_shape are checked from right to left. + For every axis, if it either: + - exist in data but not in target_shape, or + - is larger than 1 in data and equals to 1 in target_shape, + data will be summed over this axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input tensor. + + shape : Tuple[int] + The shape to collapse to. + + Returns + ------- + ret : tvm.te.Tensor + The result tensor after summation. + """ + return cpp.collapse_sum(data, target_shape) diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py index 32a7e297b04c..22f9ff58a57f 100644 --- a/python/tvm/topi/scan.py +++ b/python/tvm/topi/scan.py @@ -151,7 +151,7 @@ def gen_ir(data_buf, out_buf): def cumsum( data: tvm.te.Tensor, axis: Optional[int] = None, - dtype: Optional[int] = None, + dtype: Optional[str] = None, exclusive: Optional[bool] = None, ) -> tvm.te.Tensor: """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. diff --git a/python/tvm/topi/testing/group_norm_python.py b/python/tvm/topi/testing/group_norm_python.py index d1c0d4a6abcc..7677348426ff 100644 --- a/python/tvm/topi/testing/group_norm_python.py +++ b/python/tvm/topi/testing/group_norm_python.py @@ -51,10 +51,11 @@ def group_norm_python(data, gamma, beta, num_groups, channel_axis, axes, epsilon N-D with shape (d_0, d_1, ..., d_{N-1}) """ old_shape = data.shape + old_dtype = data.dtype new_shape = list(old_shape) new_shape[channel_axis] = data.shape[channel_axis] // num_groups new_shape.insert(channel_axis, num_groups) - data = np.reshape(data, new_shape) + data = np.reshape(data, new_shape).astype("float32") new_axes = [channel_axis + 1] for axis in axes: if axis < channel_axis: @@ -64,7 +65,7 @@ def group_norm_python(data, gamma, beta, num_groups, channel_axis, axes, epsilon mean = np.mean(data, axis=tuple(new_axes), keepdims=True) var = np.var(data, axis=tuple(new_axes), keepdims=True) data = (data - mean) / np.sqrt(var + epsilon) - data = np.reshape(data, old_shape) + data = np.reshape(data, old_shape).astype(old_dtype) gamma_broadcast_shape = [1 for _ in range(len(old_shape))] gamma_broadcast_shape[channel_axis] = gamma.shape[0] diff --git a/python/tvm/topi/testing/layer_norm_python.py b/python/tvm/topi/testing/layer_norm_python.py index 6b3b00146983..662383363b92 100644 --- a/python/tvm/topi/testing/layer_norm_python.py +++ b/python/tvm/topi/testing/layer_norm_python.py @@ -44,9 +44,12 @@ def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): result : np.ndarray N-D with shape (d_0, d_1, ..., d_{N-1}) """ + old_dtype = data.dtype + data = data.astype("float32") mean = np.mean(data, axis, keepdims=True) var = np.var(data, axis, keepdims=True) result = (data - mean) / np.sqrt(var + epsilon) + result = result.astype(old_dtype) 
result *= gamma if beta is not None: result += beta diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index abc25e89c48c..08ce082c4586 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -40,6 +40,7 @@ pub mod attrs; pub struct ExprNode { pub base: BaseExprNode, pub checked_type: Type, + pub struct_info: ObjectRef, pub virtual_device: ObjectRef, } @@ -48,6 +49,7 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::(span.clone()), checked_type: Type::null(), + struct_info: ObjectRef::null(), virtual_device: ObjectRef::null(), } } diff --git a/src/ir/function.cc b/src/ir/function.cc index ce294708b2a9..59f94201b241 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -22,6 +22,8 @@ * \brief The function data structure. */ #include +#include +#include #include #include @@ -35,13 +37,45 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } - if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttr")) { - if (Optional ret = (*f)(func, key, value)) { + }); + +TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> BaseFunc { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttrs")) { + if (Optional ret = (*f)(func, attr_map)) { + return ret.value(); + } + } + if (const auto* f = runtime::Registry::Get("relax.FuncWithAttrs")) { + if (Optional ret = (*f)(func, attr_map)) { return ret.value(); } } LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> BaseFunc { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + } // namespace tvm diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc new file mode 100644 index 000000000000..48f56d60d68c --- /dev/null +++ b/src/ir/global_info.cc @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/ir/global_info.cc + * \brief Module global info. + */ + +#include +namespace tvm { +TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); +TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { + auto n = DummyGlobalInfo(make_object()); + return n; +}); +} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..4455e2bf2c8d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -34,7 +34,8 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set, SourceMap source_map, DictAttrs attrs) { + std::unordered_set import_set, SourceMap source_map, DictAttrs attrs, + Map> global_infos) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -44,6 +45,7 @@ IRModule::IRModule(tvm::Map functions, n->import_set_ = std::move(import_set); n->source_map = source_map; n->attrs = std::move(attrs); + n->global_infos = std::move(global_infos); for (const auto& kv : n->functions) { // set global var map @@ -65,6 +67,11 @@ IRModule::IRModule(tvm::Map functions, bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { if (!equal(this->attrs, other->attrs)) return false; + if (this->global_infos.size() != other->global_infos.size()) return false; + for (const auto& kv : this->global_infos) { + if (!equal(kv.second, other->global_infos[kv.first])) return false; + } + if (functions.size() != other->functions.size()) return false; // Update GlobalVar remap for (const auto& gv : this->GetGlobalVars()) { @@ -139,6 +146,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { } reduce_temp(); hash_reduce(this->attrs); + hash_reduce(this->global_infos); } bool IRModuleNode::ContainGlobalVar(const String& name) const { @@ -262,6 +270,10 @@ void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) this->AddTypeDef(var, type, true); } +void IRModuleNode::UpdateGlobalInfo(const String& name, const Array& info) { + this->global_infos.Set(name, info); +} + void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->erase(var); @@ -382,9 +394,9 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs, Map> global_infos) { + return IRModule(funcs, types, {}, {}, attrs, global_infos); }); TVM_REGISTER_GLOBAL("ir.Module_Add") @@ -446,6 +458,11 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); +TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") + .set_body_typed([](IRModule mod, String name, Array global_info) { + mod->UpdateGlobalInfo(name, global_info); + }); + TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); @@ -454,11 +471,23 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S mod->ImportFromStd(path); }); +TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { + return mod->GetAttrs(); +}); + TVM_REGISTER_GLOBAL("ir.Module_WithAttr") .set_body_typed([](IRModule mod, String key, 
ObjectRef value) -> IRModule { return WithAttr(mod, key, value); }); +TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") + .set_body_typed([](IRModule mod, String key) -> IRModule { return WithoutAttr(mod, key); }); + +TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") + .set_body_typed([](IRModule mod, Map attr_map) -> IRModule { + return WithAttrs(mod, attr_map); + }); + TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 66b06e6b505d..619526d0b56b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -341,11 +342,13 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required, + bool traceable) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); + pass_info->traceable = std::move(traceable); data_ = std::move(pass_info); } @@ -401,7 +404,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfo(0, std::move(name), {}); + PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); n->pass_info = std::move(pass_info); data_ = std::move(n); } @@ -444,26 +447,61 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; continue; } + // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); } - mod = pass(std::move(mod), pass_ctx); + + // This handles passes that does not use Relax tuning API (untraceable passes). + // We make untraceable passes trackable when pass context has a trace (trace mode). + // When passes to trace (make_traceable) is provided from users, we only make them trackable. + if (pass_ctx->trace_stack.size() && !pass_info->traceable && + (!pass_ctx->make_traceable.defined() || + pass_ctx->make_traceable.value().count(pass_info->name))) { + // TODO(tvm-team): Currently, there are some inconsistency in the pass registration. + // 1. Some passes are not registered in ffi registry. + // 2. Some passes do not follow the name convention. (e.g., = + ) + + // Due to these problems, serialization with non-traceable passes is handled in a hacky way + // now. Find a systematic way to identify such inconsistencies and fix them. + + // In the future, we should pass the ffi key for a pass by deducing from its name. + String transform_func_key = "relax.tuning_api.Choice.default_transform_func"; + String constr_func_key = "relax.tuning_api.Choice.default_constr_func"; + + relax::Knob knob = relax::Knob( + pass_info->name, {{"Applied", relax::Choice(transform_func_key, Array(), + constr_func_key, Array())}}); + + // Add new decision to the trace at the top of the stack. + auto trace = Downcast(pass_ctx->trace_stack.back()); + trace->Add(knob, "Applied"); + // In the future, we should just have + // mod = trace->Add(knob, "enabled"); + // instead of the two lines below. 
+ mod = pass(std::move(mod), pass_ctx); + trace->SetOutMod(mod); + + } else { + mod = pass(std::move(mod), pass_ctx); + } } return mod; } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") - .set_body_typed([](int opt_level, String name, tvm::Array required) { - return PassInfo(opt_level, name, required); + .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { + return PassInfo(opt_level, name, required, traceable); }); TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -514,7 +552,8 @@ TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValu int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - PassInfo pass_info = PassInfo(opt_level, name, required); + bool traceable = args[4]; + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); }); @@ -537,7 +576,9 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, - Optional> config) { + Optional> config, Array trace_stack, + Optional> make_traceable, int num_evals, + Optional tuning_api_database) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -547,6 +588,10 @@ TVM_REGISTER_GLOBAL("transform.PassContext") if (config.defined()) { pctx->config = config.value(); } + pctx->trace_stack = std::move(trace_stack); + pctx->make_traceable = std::move(make_traceable); + pctx->num_evals = std::move(num_evals); + pctx->tuning_api_database = std::move(tuning_api_database); PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); @@ -562,7 +607,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tconfig: " << node->config; + p->stream << "\tconfig: " << node->config << "\n"; + p->stream << "\ttrace stack: " << node->trace_stack; }); class PassContext::Internal { @@ -572,6 +618,22 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; +TVM_REGISTER_GLOBAL("transform.GetTraceStack") + .set_body_method(&PassContextNode::GetTraceStack); +TVM_REGISTER_GLOBAL("transform.PushTrace") + .set_body_method(&PassContextNode::PushTrace); +TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") + .set_body_method(&PassContextNode::GetTraceStackSize); +TVM_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_REGISTER_GLOBAL("transform.SetNumEvals") + .set_body_method(&PassContextNode::SetNumEvals); +TVM_REGISTER_GLOBAL("transform.IncNumEvals") + .set_body_method(&PassContextNode::IncNumEvals); +TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") + .set_body_method(&PassContextNode::GetTuningAPIDatabase); + TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); 
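An editorial sketch of how the trace-stack globals registered above surface on the Python side; it uses only the FFI names registered in this hunk, and the relax.tuning_api Trace construction is omitted because it is defined elsewhere:

import tvm

get_stack_size = tvm.get_global_func("transform.GetTraceStackSize")
push_trace = tvm.get_global_func("transform.PushTrace")

ctx = tvm.transform.PassContext.current()
print(get_stack_size(ctx))  # size of the current trace stack (0 unless a trace was pushed)
# push_trace(ctx, trace)    # `trace` would be a relax.tuning_api Trace object

In trace mode, SequentialNode::operator() (above) appends a default "Applied" decision for every non-traceable pass to the trace at the top of this stack.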
TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); @@ -595,7 +657,7 @@ Pass PrintIR(String header, bool show_meta_data) { LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; - return CreateModulePass(pass_func, 0, "PrintIR", {}); + return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); diff --git a/src/ir/type.cc b/src/ir/type.cc index d965406e8bb0..b61a3df09107 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -25,9 +25,10 @@ #include namespace tvm { -PrimType::PrimType(runtime::DataType dtype) { +PrimType::PrimType(runtime::DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 491af6e28f77..1bde99869ed2 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -21,68 +21,6 @@ namespace tvm { namespace meta_schedule { -/*! \brief Collecting all the blocks */ -class BlockCollector : public tir::StmtVisitor { - public: - static Array Collect(const tir::Schedule& sch, - const runtime::PackedFunc f_block_filter = nullptr) { // - return BlockCollector(sch, f_block_filter).Run(); - } - - private: - /*! \brief Entry point */ - Array Run() { - std::vector results; - for (const auto& kv : sch_->mod()->functions) { - const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function - const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function - if (const auto* func = base_func.as()) { - func_name_ = gv->name_hint; - block_names_.clear(); - blocks_to_collect_.clear(); - VisitStmt(func->body); - for (const String& name : blocks_to_collect_) { - results.push_back(sch_->GetBlock(name, func_name_)); - } - } - } - return results; - } - /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, - const runtime::PackedFunc f_block_filter = nullptr) - : sch_(sch), f_block_filter_(f_block_filter) {} - /*! \brief Override the Stmt visiting behaviour */ - void VisitStmt_(const tir::BlockNode* block) override { - tir::StmtVisitor::VisitStmt_(block); - CHECK(block_names_.count(block->name_hint) == 0) - << "Duplicated block name " << block->name_hint << " in function " << func_name_ - << " not supported!"; - block_names_.insert(block->name_hint); - - // If filter function is provided, use it to selectively collect blocks. - // Otherwise collect all blocks. - Bool collect_block = Bool(true); - if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(GetRef(block)); - } - if (collect_block) { - blocks_to_collect_.push_back(block->name_hint); - } - } - - /*! \brief The schedule to be collected */ - const tir::Schedule& sch_; - /*! \brief An optional packed func that allows only certain blocks to be collected. */ - const runtime::PackedFunc f_block_filter_; - /*! \brief The set of func name and block name pair */ - std::unordered_set block_names_; - /* \brief The list of blocks to collect in order */ - Array blocks_to_collect_; - /*! \brief Name of the current PrimFunc */ - String func_name_; -}; - /*! * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks * in post-DFS order. 
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 9a372dde8f6d..955381b740c8 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -554,6 +554,68 @@ inline double Sum(const Array& arr) { return sum; } +/*! \brief Collecting all the blocks */ +class BlockCollector : public tir::StmtVisitor { + public: + static Array Collect(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) { // + return BlockCollector(sch, f_block_filter).Run(); + } + + private: + /*! \brief Entry point */ + Array Run() { + std::vector results; + for (const auto& [gv, base_func] : sch_->mod()->functions) { + // `gv->name_hint` is the name of the function + // `base_func` can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + func_name_ = gv->name_hint; + block_names_.clear(); + blocks_to_collect_.clear(); + VisitStmt(func->body); + for (const String& name : blocks_to_collect_) { + results.push_back(sch_->GetBlock(name, func_name_)); + } + } + } + return results; + } + /*! \brief Constructor */ + explicit BlockCollector(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) + : sch_(sch), f_block_filter_(f_block_filter) {} + /*! \brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + tir::StmtVisitor::VisitStmt_(block); + CHECK(block_names_.count(block->name_hint) == 0) + << "Duplicated block name " << block->name_hint << " in function " << func_name_ + << " not supported!"; + block_names_.insert(block->name_hint); + + // If filter function is provided, use it to selectively collect blocks. + // Otherwise collect all blocks. + Bool collect_block = Bool(true); + if (f_block_filter_ != nullptr) { + collect_block = f_block_filter_(GetRef(block)); + } + if (collect_block) { + blocks_to_collect_.push_back(block->name_hint); + } + } + + /*! \brief The schedule to be collected */ + const tir::Schedule& sch_; + /*! \brief An optional packed func that allows only certain blocks to be collected. */ + const runtime::PackedFunc f_block_filter_; + /*! \brief The set of func name and block name pair */ + std::unordered_set block_names_; + /* \brief The list of blocks to collect in order */ + Array blocks_to_collect_; + /*! \brief Name of the current PrimFunc */ + String func_name_; +}; + } // namespace meta_schedule } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 8293af402ed9..e09ce266b340 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -21,6 +21,8 @@ #include #include +#include + namespace tvm { TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { @@ -35,6 +37,11 @@ std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional config_dict) { runtime::ObjectPtr n = make_object(); if (auto v = config_dict.Get("name")) { @@ -49,7 +56,12 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("tir_prefix")) { n->tir_prefix = Downcast(v); } - + if (auto v = config_dict.Get("relax_prefix")) { + n->relax_prefix = Downcast(v); + } + if (auto v = config_dict.Get("module_alias")) { + n->module_alias = Downcast(v); + } if (auto v = config_dict.Get("buffer_dtype")) { n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } @@ -88,9 +100,25 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = Downcast(v)->value; } + + // Checking prefixes if they are valid Python identifiers. 
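  // The prefixes are also reported as builtin keywords of the printed script (see
  // GetBuiltinKeywords below), so each non-empty prefix or alias must be a well-formed
  // Python identifier.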
+ CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix; + CHECK(IsIdentifier(n->tir_prefix)) << "Invalid `tir_prefix`: " << n->tir_prefix; + CHECK(IsIdentifier(n->relax_prefix)) << "Invalid `relax_prefix`: " << n->relax_prefix; + CHECK(n->module_alias.empty() || IsIdentifier(n->module_alias)) + << "Invalid `module_alias`: " << n->module_alias; + this->data_ = std::move(n); } +Array PrinterConfigNode::GetBuiltinKeywords() { + Array result{this->ir_prefix, this->tir_prefix, this->relax_prefix}; + if (!this->module_alias.empty()) { + result.push_back(this->module_alias); + } + return result; +} + TVM_REGISTER_NODE_TYPE(PrinterConfigNode); TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { return PrinterConfig(config_dict); diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc new file mode 100644 index 000000000000..4132039a5e34 --- /dev/null +++ b/src/relax/analysis/analysis.cc @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file analysis.cc + * + * \brief Analysis functions for Relax. 
+ */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class VarVisitor : protected ExprVisitor { + public: + Array Free(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array Collect() { + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array AllGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + + void VisitExpr_(const FunctionNode* op) final { + for (const auto& param : op->params) { + MarkBounded(param); + } + VisitExpr(op->body); + } + + void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + + void VisitExpr_(const CallNode* call_node) final { + VisitSpan(call_node->span); + VisitExpr(call_node->op); + + for (StructInfo sinfo_arg : call_node->sinfo_args) { + VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + MarkBounded(binding->var); + VisitExpr(binding->value); + VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchCastNode* binding) final { + MarkBounded(binding->var); + ExprVisitor::VisitBinding_(binding); + } + + private: + InsertionSet vars_; + InsertionSet bound_vars_; + InsertionSet global_vars_; +}; + +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } + +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } + +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } + +tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } + +TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); + +TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc new file mode 100644 index 000000000000..9c150fed8bfd --- /dev/null +++ b/src/relax/analysis/detect_recursion.cc @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file detect_recursion.cc + * + * \brief Analysis to detect global recursive or mutually recursive functions. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +/* + * General approach to detecting recursion: + * Suppose we have a dependency graph of global functions, + * where function A depends on function B if A contains a reference to B + * (i.e., an edge A->B means A references B). If function A is recursive, + * then it has a self-edge A->A. + * + * Note that the call can happen _anywhere_ in the function's body: + * All that is important for mutual recursion is that one function + * needs the other to be in scope (it needs to know about it) to define + * the body. This includes calls that happen inside local function definitions, + * branches that may not execute, etc. + * + * Then detecting simple recursion and mutual recursion is a problem of cycle + * detection: Functions F1, F2, ..., Fn are mutually recursive if there exists + * a single directed cycle that contains all of them. + * + * We aim to find the _largest_ directed cycles in the graph, as there can + * be smaller cycles within the larger ones, as in the following example: + * + * A <-> B <-> C + * ^ | ^ + * | v | + * | D | + * | | | + * v v v + * E <-> F <-> G + * + * Handling a case like this in a directed graph is very difficult + * because most simple algorithms (variations of DFS) aim to find the smallest + * cycle, but in this case, we have multiple cycles that go through nodes multiple times: + * E.g., A->B->D->F->E->A, B->C->G->F->D->B, and A->B->C->G->F->E->A. + * However, we would consider _all_ of these nodes to be mutually recursive, + * and there is a single cycle: A->B->C->G->F->E->A->B->D->F->E->A (must go through A twice) + * + * We can use Johnson's elementary circuit-finding algorithm (1975): + * https://epubs.siam.org/doi/10.1137/0204007 + * and find all elementary circuits in the graph, which are cycles that go + * through nodes at most once. 
+ * + * With all the elementary cycles, we can coalesce different cycles that involve the + * same node, which would all form a group of mutually recursive functions + */ + +class DependencyGatherer : public ExprVisitor { + public: + explicit DependencyGatherer(const IRModule& m) : m_(m) {} + + std::unordered_set Track(const Function& func) { + this->VisitExpr(func); + return deps_; + } + + void VisitExpr_(const GlobalVarNode* gv) override { + // disregard PrimFuncs + if (!m_->Lookup(GetRef(gv)).as()) { + return; + } + deps_.insert(gv->name_hint); + } + + private: + std::unordered_set deps_; + const IRModule& m_; +}; + +using adjacency_map = std::unordered_map>; +using node_set = std::unordered_set; +using adjacency_index = std::vector; + +adjacency_map GatherDependencyGraph(const IRModule& m) { + adjacency_map ret; + for (auto gv_func : m->functions) { + const relax::FunctionNode* func = gv_func.second.as(); + // disregard PrimFuncs and the like + if (!func) { + continue; + } + std::string name = gv_func.first->name_hint; + auto deps = DependencyGatherer(m).Track(GetRef(func)); + ret.insert({name, deps}); + } + return ret; +} + +// the graph algorithm pseudocode assumes vertices are indices and makes use of the fact you can +// increment them, so for ease, we convert the strings to indices by some ordering +adjacency_index ConvertToIndices(const adjacency_map& graph, + const std::vector& ordering) { + adjacency_index ret; + for (size_t i = 0; i < ordering.size(); i++) { + std::string current = ordering[i]; + node_set neighbors; + for (size_t j = 0; j < ordering.size(); j++) { + if (graph.at(current).count(ordering[j])) { + neighbors.insert(j); + } + } + ret.push_back(neighbors); + } + return ret; +} + +/********* Strongly connected component (SCC) detection, needed for Johnson's algorithm *********/ +// Based on the pseudocode for Tarjan's SCC detection algorithm +// See: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + +// Modification: We take a min_vert parameter to ignore all vertices below that. 
+// This is because Johnson's algorithm searches for SCCs on a subgraph of +// all vertices after a certain one (per some arbitrary ordering) + +void StronglyConnect(size_t node, const adjacency_index& graph, size_t min_vert, + // use signed ints so that -1 can indicate undefined/unvisited + std::vector* indices, std::vector* low_links, + std::vector* stack, std::vector* on_stack, + std::vector* sccs, int* running_index) { + indices->operator[](node) = *running_index; + low_links->operator[](node) = *running_index; + (*running_index)++; + stack->push_back(node); + on_stack->operator[](node) = true; + + auto children = graph.at(node); + for (auto child : children) { + // ignore children outside the verts we are checking + if (child < min_vert) { + continue; + } + if (indices->at(child) == -1) { + StronglyConnect(child, graph, min_vert, indices, low_links, stack, on_stack, sccs, + running_index); + low_links->operator[](node) = std::min(low_links->at(node), low_links->at(child)); + } else if (on_stack->at(child)) { + low_links->operator[](node) = std::min(low_links->at(node), indices->at(child)); + } + } + + // root node -> have found an SCC + if (low_links->at(node) == indices->at(node)) { + node_set scc; + size_t m; + do { + m = stack->back(); + stack->pop_back(); + on_stack->operator[](m) = false; + scc.insert(m); + } while (m != node); + sccs->push_back(scc); + } +} + +std::vector FindStronglyConnectedComponents(const adjacency_index& graph, + size_t min_vert) { + std::vector stack; + std::vector sccs; + int running_index = 0; + + std::vector indices; + std::vector low_links; + std::vector on_stack; + for (size_t i = 0; i < graph.size(); i++) { + indices.push_back(-1); + low_links.push_back(-1); + on_stack.push_back(false); + } + + for (size_t i = min_vert; i < graph.size(); i++) { + StronglyConnect(i, graph, min_vert, &indices, &low_links, &stack, &on_stack, &sccs, + &running_index); + } + return sccs; +} + +/********* Helper functions needed for Johnson's algorithm *********/ + +// return strongly connected componenet containing the least vertex +node_set GetLeastSCC(const std::vector& sccs) { + int min_idx = 0; + bool min_found = false; + size_t min = 0; + for (size_t i = 0; i < sccs.size(); i++) { + if (!min_found) { + min = *(sccs[i].begin()); + min_found = true; + min_idx = i; + } + + for (size_t v : sccs[i]) { + if (v < min) { + min = v; + min_idx = i; + } + } + } + return sccs[min_idx]; +} + +size_t LeastVertex(const node_set& scc) { + bool min_found = false; + size_t min = 0; + for (size_t v : scc) { + if (!min_found) { + min = v; + min_found = true; + } + if (v < min) { + min = v; + } + } + return min; +} + +/********* Johnson's algorithm implementation *********/ +// implementation is based directly on the pseudocode from +// "Finding All the Elementary Circuits of a Directed Graph" (Johnson, 1975) + +void Unblock(std::vector* blocked, std::vector* blocked_nodes, size_t node) { + blocked->operator[](node) = false; + // copy so we don't modify the set we're iterating on + auto blocked_on_node = node_set(blocked_nodes->at(node)); + for (auto blocked_node : blocked_on_node) { + blocked_nodes->at(node).erase(blocked_node); + if (blocked->at(blocked_node)) { + Unblock(blocked, blocked_nodes, blocked_node); + } + } +} + +bool CheckCircuit(const adjacency_index& graph, const node_set& scc, + std::vector* blocked_nodes, std::vector* blocked, + std::vector* current_stack, std::vector* found_circuits, + size_t s, size_t v) { + bool f = false; + blocked->operator[](v) = true; + 
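  // v joins the elementary path currently being grown from s: it was just marked blocked so
  // the recursion below cannot revisit it, and it is recorded on the path stack next.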
current_stack->push_back(v); + for (size_t child : graph[v]) { + // ignore any node that's not in the SCC: + // the algorithm considers only the subgraph pertaining to the SCC + if (!scc.count(child)) { + continue; + } + if (child == s) { + // we found a circuit, so report it + auto new_circuit = node_set(current_stack->begin(), current_stack->end()); + new_circuit.insert(s); + found_circuits->push_back(new_circuit); + f = true; + } else if (!blocked->at(child)) { + if (CheckCircuit(graph, scc, blocked_nodes, blocked, current_stack, found_circuits, s, + child)) { + f = true; + } + } + } + if (f) { + Unblock(blocked, blocked_nodes, v); + } else { + for (size_t child : graph[v]) { + if (!scc.count(child)) { + continue; + } + if (!blocked_nodes->at(child).count(v)) { + blocked_nodes->at(child).insert(v); + } + } + } + current_stack->pop_back(); + return f; +} + +std::vector DetectElementaryCircuits(const adjacency_index& graph) { + std::vector blocked_nodes; + for (size_t i = 0; i < graph.size(); i++) { + blocked_nodes.push_back(node_set()); + } + + std::vector blocked; + for (size_t i = 0; i < graph.size(); i++) { + blocked.push_back(false); + } + std::vector current_stack; + std::vector found_circuits; + + size_t s = 0; + while (s < graph.size()) { + auto sccs = FindStronglyConnectedComponents(graph, s); + auto scc = GetLeastSCC(sccs); + s = LeastVertex(scc); + // Note: the pseudocode calls for an early exit if the subgraph is empty. + // However, that will never happen (there will always be at least one SCC + // with at least one node) + for (size_t i = s; i < graph.size(); i++) { + if (!scc.count(i)) { + continue; + } + blocked[i] = false; + blocked_nodes[i].clear(); + } + CheckCircuit(graph, scc, &blocked_nodes, &blocked, ¤t_stack, &found_circuits, s, s); + s++; + } + return found_circuits; +} + +/********* Coalescing the circuits and returning the results *********/ + +// given all elementary circuits, we want to coalesce any circuits that share nodes +std::vector CoalesceCircuits(const std::vector& circuits) { + std::vector ret; + std::unordered_set merged; + bool changed = false; + for (size_t i = 0; i < circuits.size(); i++) { + if (merged.count(i)) { + continue; + } + node_set current(circuits[i].begin(), circuits[i].end()); + for (size_t j = i + 1; j < circuits.size(); j++) { + if (merged.count(j)) { + continue; + } + for (size_t member : current) { + if (circuits[j].count(member)) { + changed = true; + merged.insert(j); + current.insert(circuits[j].begin(), circuits[j].end()); + } + } + } + ret.push_back(current); + } + // try again if something changed, as there may be more chances to coalesce + if (changed) { + return CoalesceCircuits(ret); + } + return ret; +} + +tvm::Array> DetectRecursion(const IRModule& m) { + auto graph = GatherDependencyGraph(m); + + // have to decide on some ordering for names + std::vector name_ordering; + for (auto kv : graph) { + name_ordering.push_back(kv.first); + } + + auto indices = ConvertToIndices(graph, name_ordering); + auto groups = CoalesceCircuits(DetectElementaryCircuits(indices)); + + // convert to expected representation + tvm::Array> ret; + for (auto group : groups) { + tvm::Array found; + for (size_t node : group) { + found.push_back(m->GetGlobalVar(name_ordering[node])); + } + ret.push_back(found); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/layout_transformation.cc 
b/src/relax/analysis/layout_transformation.cc new file mode 100644 index 000000000000..44538fea98e5 --- /dev/null +++ b/src/relax/analysis/layout_transformation.cc @@ -0,0 +1,621 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/layout_transormation.cc + * \brief Analyze the PrimFunc and suggest layout transformation on it's blocks and buffers based on + * the user provided layout transformations on it's outputs. + */ +#include +#include +#include +#include + +#include "../../support/array.h" + +namespace tvm { +namespace relax { + +using namespace tir; + +/********** Helper Functions **********/ + +/*! \brief Checks if a transformation is bijective affine over the given ranges */ +static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { + Map input_iters; + ICHECK_EQ(m->initial_indices.size(), ranges.size()); + for (size_t i = 0; i < ranges.size(); i++) { + input_iters.Set(m->initial_indices[i], ranges[i]); + } + arith::Analyzer analyzer; + auto iter_map_result = DetectIterMap(m->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*simplify_trivial_iterators=*/true); + return !iter_map_result->indices.empty(); +} + +/*! + * \brief Analyzer to collect iterators from IterSumExpr. + * \details Analyzes the indices from DetectIterMap analysis to collect the spatial iterators that + * are used in it. This is important to get which spatial iterators are accessed in each index + * of buffer access. + */ +class IndexAnalyzer : public ExprVisitor { + public: + Array Analyze(const arith::IterSumExpr& expr) { + VisitExpr(expr); + return iterators_; + } + + private: + /*! \brief Override VisitExpr for iter expr type processing */ + void VisitExpr(const PrimExpr& expr) override { + if (const auto* op = expr.as()) { + for (const auto& arg : op->args) VisitExpr(arg); + VisitExpr(op->base); + return; + } + if (const auto* op = expr.as()) { + VisitIterMark(op->source); + VisitExpr(op->lower_factor); + VisitExpr(op->extent); + VisitExpr(op->scale); + return; + } + return ExprVisitor::VisitExpr(expr); + } + + void VisitIterMark(const arith::IterMark& op) { + if (const auto* var = op->source.as()) + iterators_.push_back(GetRef(var)); + else + VisitExpr(op->source); + VisitExpr(op->extent); + } + + private: + Array iterators_; +}; + +/*! + * \brief Analyzes IterMapResult to get the Spatial Layout of buffer access. + * \details We define Spatial Layout of a buffer access as an array of length equal to the + * dimensions of the buffer. i-th element of Spatial Layout contains spatial iter var used from the + * block iteration domain. For indices, where no spatial iter vars are used, the spatial layout + * element is empty. 
If any of the buffer access indices use multiple spatial iter vars, the spatial + * layout is undefined. + * + * Here are a few examples of inferred spatial layout from buffer access. si denotes i-th spatial + * iter var, and ri denotes i-th reduction iter var. + * + * SpatialLayout(A[s0*constant, s1]) = {s0, s1} + * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1} + * SpatialLayout(A[s0 * c + s1]) = undefined + */ +using SpatialLayout = Array>; +static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { + ICHECK(!iter_map_result->indices.empty()); + SpatialLayout result; + for (const arith::IterSumExpr& index : iter_map_result->indices) { + IndexAnalyzer index_analyzer; + Array iter_vars = index_analyzer.Analyze(index); + if (iter_vars.size() >= 2) { + LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: " + << arith::NormalizeIterMapToExpr(index); + return {}; + } + if (iter_vars.empty()) { + result.push_back({}); + continue; + } + result.push_back(iter_vars[0]); + } + return result; +} + +/*! + * \brief Checks if the two spatial layouts are identical. Two empty spatial layouts are treated as + * unequal. + */ +static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayout& s1) { + if (s0.empty() || s1.empty()) return false; + if (s0.size() != s1.size()) return false; + for (size_t i = 0; i < s0.size(); ++i) { + if ((!s0[i].defined() && s1[i].defined()) || (s0[i].defined() && !s1[i].defined())) + return false; + if (!s0[i].same_as(s1[i])) return false; + } + return true; +} + +/*! + * \brief Checks if the block accesses a buffer sequentially in terms of spatial dimensions + * (ignoring reduction dimensions). It checks that the order of spatial iter vars in spatial layout + * of a buffer access is same as the order of spatial iter vars in block domain. + */ +using VarToBlockIndexMap = std::unordered_map; +static bool IsSequentialAccess(const SpatialLayout& iterators, + const VarToBlockIndexMap& iter_to_block_index) { + int last_value = -1; + for (const auto& i : iterators) { + if (!i.defined()) continue; + auto it = iter_to_block_index.find(i.value()); + ICHECK(it != iter_to_block_index.end()); + int blk_index = it->second; + if (blk_index <= last_value) return false; + last_value = blk_index; + } + return true; +} + +/*! \brief Checks if two IndexMaps represent identical transforms */ +static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { + if (t0->initial_indices.size() != t1->initial_indices.size()) return false; + if (t0->final_indices.size() != t1->final_indices.size()) return false; + + // Create a new shape expression. + Array t1_initial_indices = + t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); + auto t0_output = t0->MapIndices(t1_initial_indices); + arith::Analyzer analyzer; + for (size_t i = 0; i < t0_output.size(); ++i) { + if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false; + } + return true; +} + +/*! + * \brief Returns the layout transformation for a target spatial layout from the source spatial + * layout and transformation. + * \details Given the source buffer spatial layout \p src_spatial_layout and its transformation \p + * src_transformation, this function constructs the transformation for the target buffer whose + * spatial layout is given as \p tgt_spatial_layout. 
+ * + * The algorithm is explained below using an example: + * + * Let's say the source transformation is lambda N, C, H, W -> (N, H, W, C // 4, C % + * 4), source spatial layout is 'NCHW' and target spatial layout is 'KCHW'. + * + * Step 1: Copy over the source transformation initial & final indices to target transformation + * initial and final indices. + * target transformation = lambda N, C, H, W -> (N, H, W, C // 4, C %4) + * + * Step 2: Drop any vars from initial indices which do not occur in target buffer using source and + * target spatial layouts. + * target transformation = lambda C, H, W -> (N, H, W, C // 4, C %4) + * + * Step 3: Erase any expression from final indices which is dependent on a var not present in + * initial indices. + * target transformation = lambda C, H, W -> (H, W, C // 4, C %4) + * + * Step 4: Go over the target spatial layout and add any missing dims to both initial and final + * indices. This is done by checking if any iterator in target spatial layout is not present in + * source spatial layout. + * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) + */ +using VarSet = std::unordered_set; +static Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, + const IndexMap& src_transformation, + const SpatialLayout& tgt_spatial_layout) { + // Copy over the src transformation intial and final indices + auto initial_indices = support::AsList(src_transformation->initial_indices); + auto final_indices = support::AsList(src_transformation->final_indices); + + // Get the iterator var set used in target spatial layout. + VarSet tgt_var_set; + for (const auto& i : tgt_spatial_layout) { + if (i.defined()) tgt_var_set.insert(i.value()); + } + + // Erase initial indices corresponding to iter vars that do not occur in target spatial layout. + // Also compute the var set of initial indices. + auto initial_indices_it = initial_indices.begin(); + VarSet initial_indices_var_set; + for (const auto& i : src_spatial_layout) { + ICHECK(i.defined()); + if (tgt_var_set.count(i.value())) { + initial_indices_var_set.insert(*initial_indices_it); + initial_indices_it++; + continue; + } + initial_indices_it = initial_indices.erase(initial_indices_it); + } + + // Erase any expressions in final indices that have undefined vars + auto final_indices_it = final_indices.begin(); + while (final_indices_it != final_indices.end()) { + // Collect all the vars used in this final index. + Array used_vars = tir::UndefinedVars(*final_indices_it); + ICHECK(!used_vars.empty()) + << "IndexMap expression must always contain tir::Var nodes but found none in: " + << *final_indices_it; + + bool has_undefined_vars = std::any_of(used_vars.begin(), used_vars.end(), + [&initial_indices_var_set](const tir::Var& v) { + return initial_indices_var_set.count(v) == 0; + }); + + // If all vars are from initial indices, nothing to do for this final index. + if (!has_undefined_vars) { + final_indices_it++; + continue; + } + // We are about to drop this expr from final indices since it has undefined vars. Check if it is + // dependent on any of the initial indices. If it is dependent, this cannot be dropped and we + // bail by returning null. + // This captures the scenario where the source transformation is unpacking a dimension (e.g, + // "H4h" -> "H*4+h" ) and the buffer we are trying to infer the transformation of has 'h' + // dimension, but not 'H'. So, it is dependent on undefined var 'H' and defined var 'h'. 
+ bool depends_on_initial_indices = std::any_of(used_vars.begin(), used_vars.end(), + [&initial_indices_var_set](const tir::Var& v) { + return initial_indices_var_set.count(v) != 0; + }); + if (depends_on_initial_indices) { + LOG(WARNING) + << "[LayoutInference] Buffer access is dependent on both defined and undefined vars"; + return {}; + } + // It is ok to erase this final index expression as it only depends on undefined vars. + final_indices_it = final_indices.erase(final_indices_it); + } + + // Go over the target spatial layout and add any missing dims to both initial and final indices. + // This is done by checking if any iterator in target spatial layout is not present in source + // spatial layout. + VarSet src_var_set; + for (const auto& i : src_spatial_layout) { + ICHECK(i.defined()); + src_var_set.insert(i.value()); + } + + initial_indices_it = initial_indices.begin(); + final_indices_it = final_indices.begin(); + for (const auto& i : tgt_spatial_layout) { + if (i.defined() && src_var_set.count(i.value())) { + initial_indices_it++; + if (final_indices_it != final_indices.end()) final_indices_it++; + continue; + } + + auto new_dim = tir::Var("d"); + initial_indices.insert(initial_indices_it, new_dim); + final_indices.insert(final_indices_it, new_dim); + } + + return IndexMap(support::AsArray(initial_indices), support::AsArray(final_indices)); +} + +/*! + * \brief Analyzes the Block and given output buffer transformations to propose + * transformations of block and read buffers. + * \details It does a best effort analysis to propose transformations which would preserve + * sequential access to buffers (especially output buffers). Since this is best effort, it is + * possible that the Block is too complex for analysis. In such a case, no transformations are + * proposed. Limitations: + * 1. Expects exactly one write buffer in the block whose transformation is given by + * `write_transformation`. + * 2. Expects write buffer access to be affine and only use spatial iterators of the block. + * 3. Proposes transformations to a read buffer if all access to it are affine. + */ +class BlockAnalyzer : public StmtExprVisitor { + public: + explicit BlockAnalyzer(const Block& block, const Map& transformation_cache, + IndexMap write_transformation) + : can_transform_block_(true), + write_transformation_(write_transformation), + block_(block), + buffer_transformation_cache_(transformation_cache) { + ICHECK(block_->writes.size() == 1); + auto write_buffer = block_->writes[0]->buffer; + + ComputeBlockSpatialDomain(); + + // Visit the block body to collect load/store access patterns of different buffers. + VisitStmt(block_->body); + + // While visiting the load/store accesses it is possible we see an unexpected pattern, such as + // nested block or write access to multiple buffers. In such a case, we can return early as we + // would not be making any layout suggesstions. + if (!can_transform_block_) { + LOG(WARNING) << "[LayoutInference] Unable to transform block " << block->name_hint; + return; + } + + // Get iterator ordering and it's spatial layout. + VarToBlockIndexMap iter_var_to_block_index; + SpatialLayout block_spatial_layout; + int index = 0; + for (const auto& iter_var : block->iter_vars) { + auto var = iter_var->var; + iter_var_to_block_index[var] = index++; + block_spatial_layout.push_back(var); + } + + // Helper to get the spatial layout of buffer from buffer access map. 
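    // An empty layout is returned when the buffer was never accessed in this block or when
    // its recorded accesses were inconsistent with each other.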
+ auto get_spatial_layout = [&](Buffer b) -> SpatialLayout { + auto it = buffer_access_info_.find(b); + if (it == buffer_access_info_.end()) { + return {}; + } + auto access_info = it->second; + return access_info.GetValidSpatialLayout(); + }; + + // Check that write has sequential access within the block. + SpatialLayout write_spatial_layout = get_spatial_layout(write_buffer); + if (write_spatial_layout.empty()) { + can_transform_block_ = false; + return; + } + if (!IsSequentialAccess(write_spatial_layout, iter_var_to_block_index)) { + can_transform_block_ = false; + return; + } + + // Infer Block transformation from write buffer transformation. + auto maybe_block_transformation = InferLayoutTransformation( + write_spatial_layout, write_transformation_, block_spatial_layout); + if (!maybe_block_transformation.defined()) { + can_transform_block_ = false; + return; + } + block_transformation_ = maybe_block_transformation.value(); + + Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); + if (!IsBijectiveAffine(block_transformation_, block_ranges)) { + can_transform_block_ = false; + LOG(WARNING) << "[LayoutInference] Inferred block transformation is not bijective affine, " + "transformation: (" + << block_transformation_ << ") over range (" << block_ranges << ")"; + return; + } + + // Infer read buffer transformations from write buffer transformation. + for (const auto& r : block->reads) { + SpatialLayout read_spatial_layout = get_spatial_layout(r->buffer); + if (read_spatial_layout.empty()) continue; + if (!IsSequentialAccess(read_spatial_layout, iter_var_to_block_index)) continue; + + auto maybe_read_transformation = InferLayoutTransformation( + write_spatial_layout, write_transformation_, read_spatial_layout); + if (!maybe_read_transformation.defined()) continue; + IndexMap read_transformation = maybe_read_transformation.value(); + if (buffer_transformation_cache_.count(r->buffer) != 0) { + if (!AreIdenticalTransforms(read_transformation, buffer_transformation_cache_[r->buffer])) + LOG(WARNING) << "[LayoutInference] Buffer: " << r->buffer + << " has conflicting transform proposals -- (preferred) " + << buffer_transformation_cache_[r->buffer] << " vs. " << read_transformation; + continue; + } + read_buffer_transformations_.Set(r->buffer, read_transformation); + } + } + + private: + // Helper class to keep track of spatial layout of buffer as we visit multiple accesses to this + // buffer within the block. + class BufferAccessInfo { + public: + BufferAccessInfo() : is_valid_(true) {} + void Update(SpatialLayout s) { + if (!IsValid()) return; + if (spatial_layout_.empty()) spatial_layout_ = s; + if (!AreIdenticalSpatialAccess(s, spatial_layout_)) { + Invalidate(); + return; + } + } + bool IsValid() { return is_valid_; } + void Invalidate() { is_valid_ = false; } + SpatialLayout GetValidSpatialLayout() { + if (!IsValid()) return {}; + return spatial_layout_; + } + + private: + bool is_valid_; + SpatialLayout spatial_layout_; + }; + + // Helper to break down the indices of buffer access. 
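  // An empty spatial layout is returned when DetectIterMap cannot analyze the access against
  // the block's spatial domain.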
+ SpatialLayout DetectBufferAccessIterMap(Array indices) { + auto result = arith::DetectIterMap( + /*indices=*/indices, /*input_iters*/ spatial_dom_, + /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); + if (result->indices.empty()) { + LOG(WARNING) << "[LayoutInference] Failed to analyze indices " << indices + << ", error: " << result->errors; + return {}; + } + return GetSpatialLayout(result); + } + + // Compute the spatial domain map of block + void ComputeBlockSpatialDomain() { + for (const IterVar& v : block_->iter_vars) { + if (v->iter_type == kDataPar) { + spatial_dom_.Set(v->var, v->dom); + continue; + } + if (v->iter_type == kCommReduce) continue; + LOG(WARNING) << "[LayoutInference] Cannot compute block spatial domain in presence of " + "unknown block iter_type : " + << v->iter_type; + can_transform_block_ = false; + return; + } + } + + void VisitStmt_(const BlockNode* op) final { + // Blocks with nested blocks cannot be handled yet. + LOG(WARNING) << "[LayoutInference] Nested blocks are not supported for layout inference yet"; + can_transform_block_ = false; + } + void VisitStmt_(const BufferStoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + + BufferAccessInfo& access_info = buffer_access_info_[op->buffer]; + + // Fast path to ignore further analysis if we know that the buffer access is invalid. + if (!access_info.IsValid()) return; + + // Only single write buffer is supported for each block. + if (!op->buffer.same_as(block_->writes[0]->buffer)) { + access_info.Invalidate(); + LOG(WARNING) << "[LayoutInference] Exactly one write buffer is supported for layout " + "inference, found two: " + << op->buffer << " and " << block_->writes[0]->buffer; + can_transform_block_ = false; + return; + } + + // If the write buffer access cannot be analyzed, no transformation to the block will be made. + auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices); + if (detected_spatial_layout.empty()) { + access_info.Invalidate(); + return; + } + + // Check if we have access info for this buffer, if present, the two accesses must be + // identical. + access_info.Update(detected_spatial_layout); + } + + void VisitExpr_(const BufferLoadNode* op) final { + Buffer read_buffer = op->buffer; + BufferAccessInfo& access_info = buffer_access_info_[op->buffer]; + + auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices); + + if (detected_spatial_layout.empty()) { + access_info.Invalidate(); + return; + } + access_info.Update(detected_spatial_layout); + } + + public: + bool CanBeTransformed() { return can_transform_block_; } + IndexMap GetBlockTransformation() { return block_transformation_; } + Map GetReadBufferTransformations() { return read_buffer_transformations_; } + + private: + bool can_transform_block_; + IndexMap write_transformation_; + Map spatial_dom_; + arith::Analyzer arith_analyzer_; + + Block block_; + IndexMap block_transformation_; + + Map read_buffer_transformations_; + const Map& buffer_transformation_cache_; + std::unordered_map buffer_access_info_; +}; + +/*! + * \brief Analyzes the PrimFunc and user provided output buffer transformations to propose + * transformations of block and buffers within the PrimFunc. + * \details It does a best effort analysis to propose transformations which would preserve + * sequential access to buffers (especially output buffers). Since this is best effort, it is + * possible that the PrimFunc is too complex for analysis. In such a case, no transformations are + * proposed. 
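 * The given write transformations are matched, in order, with the trailing parameters of the
 * PrimFunc, i.e. its output buffers.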
+ */ +class PrimFuncAnalyzer : public StmtExprVisitor { + public: + explicit PrimFuncAnalyzer(const PrimFunc& func, Array write_transformations) { + ICHECK_LE(write_transformations.size(), func->params.size()) + << "Incompatible PrimFunc and write_transformations"; + + size_t first_write_index = func->params.size() - write_transformations.size(); + for (size_t i = 0; i < write_transformations.size(); ++i) { + auto param = func->params[first_write_index + i]; + Optional param_buf = func->buffer_map.Get(param); + ICHECK(param_buf.defined()); + ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size()) + << "Mismatch between output buffer shape and index map"; + buffer_transformation_cache_.Set(param_buf.value(), write_transformations[i]); + } + VisitStmt(func->body); + } + Map> GetSuggestedTransforms() { + Map> result; + for (const auto& [block, index_map] : block_transformations_) { + Map block_transformations; + block_transformations.Set(block, index_map); + for (const auto& buffer : block_to_buffer_[block]) { + block_transformations.Set(buffer, buffer_transformation_cache_[buffer]); + } + result.Set(block, block_transformations); + } + return result; + } + + private: + void VisitStmt_(const BlockNode* op) final { + if (op->name_hint == "root") { + // Skip the root block + StmtVisitor::VisitStmt_(op); + return; + } + + Block block = GetRef(op); + // Get block write buffer transformation. + if (block->writes.size() != 1) return; + auto write_buffer = block->writes[0]->buffer; + block_to_buffer_[block].push_back(write_buffer); + BlockAnalyzer block_analyzer(block, buffer_transformation_cache_, + buffer_transformation_cache_[write_buffer]); + + if (!block_analyzer.CanBeTransformed()) return; + // Collect the suggested transformations + block_transformations_.Set(block, block_analyzer.GetBlockTransformation()); + + for (const auto& [buffer, index_map] : block_analyzer.GetReadBufferTransformations()) { + // BlockAnalyzer makes sure that it does not propose transformation for a buffer for which a + // transformation has already been proposed by other blocks or by write_transformations which + // are input to this analysis. + ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0); + buffer_transformation_cache_.Set(buffer, index_map); + block_to_buffer_[block].push_back(buffer); + } + } + + private: + Map buffer_transformation_cache_; + Map block_transformations_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; +}; + +Map> SuggestLayoutTransforms( + const PrimFunc& prim_func, Array write_buffer_transformations) { + // No changes to the PrimFunc are required if no transformations on output buffers. + if (write_buffer_transformations.empty()) return {}; + + PrimFuncAnalyzer analyzer(prim_func, write_buffer_transformations); + return analyzer.GetSuggestedTransforms(); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) + .set_body_typed([](PrimFunc fn, Array write_buffer_transformations) { + return SuggestLayoutTransforms(fn, write_buffer_transformations); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc new file mode 100644 index 000000000000..70ce5ac06e90 --- /dev/null +++ b/src/relax/analysis/shape_analysis.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file shape_analysis.cc + * + * \brief Utilities for shape analysis. + */ + +#include +#include + +namespace tvm { +namespace relax { + +bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!ana->CanProveEqual(lhs[i], rhs[i])) return false; + } + return true; +} + +bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + + if (lhs_shape && rhs_shape) { + return CanProveShapeEqual(lhs_shape->values, rhs_shape->values, ana); + } else { + return false; + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc new file mode 100644 index 000000000000..7dfcd60c952e --- /dev/null +++ b/src/relax/analysis/struct_info_analysis.cc @@ -0,0 +1,865 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_analysis.cc + * \brief Implementations of foundation struct info analysis + * + * \note Update this file when you added a new StructInfo. 
+ */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +//-------------------------- +// GetStaticType +//-------------------------- +class StaticTypeDeriver : public StructInfoFunctor { + public: + Type VisitStructInfo_(const ObjectStructInfoNode* op) final { return ObjectType(op->span); } + + Type VisitStructInfo_(const PrimStructInfoNode* op) final { + return PrimType(op->dtype, op->span); + } + + Type VisitStructInfo_(const ShapeStructInfoNode* op) final { + return ShapeType(op->ndim, op->span); + } + + Type VisitStructInfo_(const TensorStructInfoNode* op) final { + return DynTensorType(op->ndim, op->dtype); + } + + Type VisitStructInfo_(const TupleStructInfoNode* op) final { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + return TupleType(fields, op->span); + } + + Type VisitStructInfo_(const FuncStructInfoNode* op) final { + if (op->IsOpaque()) return PackedFuncType(op->span); + Array params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + Type ret = this->VisitStructInfo(op->ret); + return FuncType(params, ret, {}, {}, op->span); + } +}; + +Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } + +TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { + return GetStaticType(info); +}); + +//-------------------------- +// StructInfoFromType +//-------------------------- + +StructInfo StructInfoFromType(const Type& type) { + if (type.as()) { + return ObjectStructInfo(type->span); + } else if (const PrimTypeNode* prim_type = type.as()) { + return PrimStructInfo(prim_type->dtype, prim_type->span); + } else if (const ShapeTypeNode* shape_type = type.as()) { + return ShapeStructInfo(shape_type->ndim, type->span); + } else if (const DynTensorTypeNode* tensor_type = type.as()) { + return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); + } else if (const TupleTypeNode* tuple_type = type.as()) { + Array fields; + for (const Type& field : tuple_type->fields) { + fields.push_back(StructInfoFromType(field)); + } + return TupleStructInfo(fields, type->span); + } else if (const FuncTypeNode* func_type = type.as()) { + Array params = + func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); + StructInfo ret = StructInfoFromType(func_type->ret_type); + return FuncStructInfo(params, ret, func_type->span); + } else { + LOG(FATAL) << "Unsupported type: " << type; + return StructInfo(); + } +} + +//-------------------------- +// EraseToWellDefined +//-------------------------- +class WellDefinedEraser : public StructInfoMutator, + public ExprMutatorBase, + public tir::ExprMutator { + public: + WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) + : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { + bool has_undefined = false; + Optional> values; + + if (op->values.defined()) { + std::swap(has_undefined_, has_undefined); + values = op->values.value().Map([&](PrimExpr val) { return this->VisitPrimExpr(val); }); + std::swap(has_undefined_, has_undefined); + } + // erase symbolic shape if we have undefined. 
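    // Here the local has_undefined records whether any symbolic variable inside op->values
    // failed to resolve through the provided var maps; the member flag has been swapped back
    // to its outer value.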
+ if (!has_undefined) { + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } + } else { + return ShapeStructInfo(op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { + bool has_undefined = false; + Optional shape; + + if (op->shape.defined()) { + std::swap(has_undefined_, has_undefined); + shape = relax::ExprMutatorBase::VisitExpr(op->shape.value()); + std::swap(has_undefined_, has_undefined); + } + + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + if (shape.defined()) { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final { + // NOTE: we always require func struct info to be well-defined. + // + // All the occuring symbolic variables are defined in parameters' + // struct info annotations. So there is no needed to erase. + return GetRef(op); + } + + using relax::ExprMutatorBase::VisitExpr_; + using tir::ExprMutator::VisitExpr_; + + // connect things up + PrimExpr VisitPrimExpr(const PrimExpr& expr) { + // apply eager simplification + PrimExpr val = tir::ExprMutator::VisitExpr(expr); + if (!val.same_as(expr)) { + return ana_->Simplify(val); + } else { + return val; + } + } + + Expr VisitExpr_(const DataflowVarNode* var) final { + return VisitExpr_(static_cast(var)); + } + + Expr VisitExpr_(const VarNode* var) final { + Optional ret; + if (f_var_map_ != nullptr) { + ret = f_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + if (ret.defined()) { + ICHECK(ret.as() || ret.as()) + << "Only allow Expr in StructInfo to be ShapeExpr or Var"; + } + return ret.value_or(GetRef(var)); + } + + PrimExpr VisitExpr_(const tir::VarNode* var) final { + Optional ret; + if (f_shape_var_map_ != nullptr) { + ret = f_shape_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + + if (ret.defined()) { + PrimExpr value = ret.value(); + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; + return value; + } else { + return GetRef(var); + } + } + + private: + bool has_undefined_ = false; + std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const Var& var)> f_var_map_; + arith::Analyzer* ana_; +}; + +StructInfo EraseToWellDefined( + const StructInfo& info, std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); + } else { + return WellDefinedEraser(f_shape_var_map, f_var_map, ana).VisitStructInfo(info); + } +} + +StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; + + if (!shape_var_map.empty()) { + f_shape_var_map = [&](const tir::Var& var) -> Optional { + auto it = shape_var_map.find(var); + if (it != shape_var_map.end()) return (*it).second; + return NullOpt; + }; + } + + if (!var_map.empty()) { + f_var_map = [&](const 
Var& var) -> Optional { + auto it = var_map.find(var); + if (it != var_map.end()) return (*it).second; + return NullOpt; + }; + } + + return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); +} + +TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") + .set_body_typed([](const StructInfo& info, Map shape_var_map, + Map var_map) { + return EraseToWellDefined(info, shape_var_map, var_map); + }); + +//-------------------------- +// IsBaseOf +//-------------------------- +class StructInfoBaseChecker + : public StructInfoFunctor { + public: + explicit StructInfoBaseChecker(arith::Analyzer* ana) : analyzer_(ana) {} + + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + // quick path + // Note: subclass may disable this quick path if we need to go over all struct info. + if (lhs.same_as(other)) return BaseCheckResult::kPass; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is base of everything + BaseCheckResult VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return BaseCheckResult::kPass; + } + + BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0; + } + + BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs have unknown ndim + if (lhs->IsUnknownNdim()) return BaseCheckResult::kPass; + + // ndim must match + if (lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs does not have symbolic value + if (!lhs->values.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs do. 
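+ // Illustrative outcomes of this ShapeStructInfo check (sketch, not exhaustive):
+ //   base = R.Shape(ndim=2), derived = R.Shape(ndim=3)          -> kFailL0 (can never match)
+ //   base = R.Shape(ndim=2), derived = R.Shape() (unknown ndim) -> kFailL1 (static info missing)
+ //   base = R.Shape([n, 4]), derived = R.Shape(ndim=2)          -> kFailL2 (needs a runtime check)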
+ if (!rhs->values.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->values.value(), rhs->values.value()); + } + + BaseCheckResult VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // dtype mismatch + if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) { + if (rhs->IsUnknownDtype()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // ndim mismatch + if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs does not have defined shape and everything else matches + if (!lhs->shape.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs don't + if (!rhs->shape.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->shape.value(), rhs->shape.value()); + } + + BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return ArrayCheck(lhs->fields, rhs->fields); + } + + BaseCheckResult VisitStructInfo_(const FuncStructInfoNode* lhs, + const StructInfo& other) override { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + // function proving is best effort. + return lhs->derive_func.same_as(rhs->derive_func) ? BaseCheckResult::kPass + : BaseCheckResult::kFailL2; + } + // no derivation function, only depends on ret + return this->VisitStructInfo(lhs->ret, rhs->ret); + } + + // Function check is best effort. + // rhs is opaque but lhs is not + if (rhs->IsOpaque()) return BaseCheckResult::kFailL2; + + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check and not ArrayCheck. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort BaseArrayCheck. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), other)) return BaseCheckResult::kPass; + + auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); + auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); + return CombineCheck(param_check, ret_check); + } + + protected: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // customizable functions. + /*! + * \brief Check symbolic shape value equivalence. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + // get static shape checking right. 
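+ // For instance (illustrative): 4 vs 4 passes, 4 vs 8 is kFailL0 since two
+ // distinct constants can never be equal, n vs n * 1 passes because the
+ // analyzer proves equality, and n vs m is kFailL2 because only a runtime
+ // check can decide it.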
+ auto* int_lhs = lhs.as(); + auto* int_rhs = rhs.as(); + if (int_lhs && int_rhs) { + if (int_lhs->value == int_rhs->value) { + return BaseCheckResult::kPass; + } else { + return BaseCheckResult::kFailL0; + } + } + return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2; + } + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult ShapeMatchCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + + BaseCheckResult ret = BaseCheckResult::kPass; + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = PrimValueMatchCheck(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } + + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return Check result. + */ + virtual BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) { + if (lhs.same_as(rhs)) return BaseCheckResult::kPass; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } else { + return BaseCheckResult::kFailL2; + } + } + + /*! + * \brief CheckShape function parameters. + * \param lhs The left hand params. + * \param rhs The right hand params. + * \return Check result. + */ + virtual BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) { + auto res = ArrayCheck(lhs, rhs); + // treat L1 failures in params checking as L2. + if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; + return res; + } + // helper functions + /*! + * \brief Combine check results. + * \param lhs The left operand. + * \param rhs The righr operand. + * \return The check result. + */ + static BaseCheckResult CombineCheck(BaseCheckResult lhs, BaseCheckResult rhs) { + if (lhs == BaseCheckResult::kFailL0 || rhs == BaseCheckResult::kFailL0) { + return BaseCheckResult::kFailL0; + } + if (lhs == BaseCheckResult::kFailL1 || rhs == BaseCheckResult::kFailL1) { + return BaseCheckResult::kFailL1; + } + if (lhs == BaseCheckResult::kFailL2 || rhs == BaseCheckResult::kFailL2) { + return BaseCheckResult::kFailL2; + } + return BaseCheckResult::kPass; + } + + /*! + * \brief Generic helper function to check arrays. + * \param lhs The left operand. + * \param rhs The right operand. 
+ */ + BaseCheckResult ArrayCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + BaseCheckResult ret = BaseCheckResult::kPass; + + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } +}; + +BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoBaseChecker(&inst)(base, derived); + } else { + return StructInfoBaseChecker(ana)(base, derived); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { + return static_cast(StructInfoBaseCheck(base, derived)); + }); + +bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { + return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; +} + +TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) { + return IsBaseOf(base, derived); + }); + +//-------------------------- +// DeriveStructInfo +//-------------------------- + +// NOTE: we are reusing StructInfoBaseChecker here to populate a mapping +// from the expressions in arg(rhs) to var in param. +class CallRetStructInfoDeriver : public StructInfoBaseChecker { + public: + explicit CallRetStructInfoDeriver(arith::Analyzer* ana) : StructInfoBaseChecker(ana) {} + + // No short cut, so we can recursively populate all pairs. + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + StructInfo Derive(const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + // opaque derivation + if (finfo->IsOpaque()) { + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + return finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + return finfo->ret; + } + } + + // Normal function signature derivation. + auto params = finfo->params.value(); + if (params.size() != call->args.size()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "number of arguments and parameters mismatch:" + << " expected " << params.size() << ", given " << call->args.size()); + } + // Visit each param arg pair, check and populate the var map + for (size_t i = 0; i < params.size(); ++i) { + auto arg_sinfo = GetStructInfo(call->args[i]); + BaseCheckResult res = this->VisitStructInfo(params[i], arg_sinfo); + // Report error if we find L1 level failure + // L2 level is best effort so we don't report. + // The behavior of L2 can be customized later. + if (res == BaseCheckResult::kFailL0 || res == BaseCheckResult::kFailL1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Argument " << i << " type mismatch:" + << " expected " << params[i] << ", given " << arg_sinfo); + } + } + // map the ret using the populated var map. + return EraseToWellDefined(finfo->ret, shape_var_map_, var_map_); + } + + protected: + // Whether to populate map in params. + bool populate_mapping_{true}; + // for simplicity, we make these fields public so the user can access them. 
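+ // Worked example (sketch): deriving the call return for
+ //
+ //   @R.function
+ //   def f(x: R.Tensor((n, 4), "float32")) -> R.Tensor((n * 2,), "float32"): ...
+ //
+ // with an argument of shape (3, 4) records n -> 3 in shape_var_map_, so the
+ // EraseToWellDefined call above rewrites the return to R.Tensor((6,), "float32").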
+ Map shape_var_map_; + Map var_map_; + + using StructInfoBaseChecker::ShapeMatchCheck; + + // Match shape values in between param(lhs) and arg(rhs) + BaseCheckResult PrimValueMatchCheck(const PrimExpr& param, const PrimExpr& arg) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + + if (auto* ptr = param.as()) { + auto var = GetRef(ptr); + auto it = shape_var_map_.find(var); + // not populated + if (it == shape_var_map_.end()) { + shape_var_map_.Set(var, arg); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + PrimExpr mapped_value = (*it).second; + if (analyzer_->CanProveEqual(mapped_value, arg)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } else { + // Best effort + // Do not attempt to do prove when param contains a symbolic expr. + // such expression might depends on a later defined var in params created by dyn fusion. + // example: f(a: Tensor[(n+1)], s: Shape[(n,)]), the (n+1) case here. + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + } + + BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::ShapeMatchCheck(lhs, rhs); + } + + if (auto* ptr = lhs.as()) { + auto var = GetRef(ptr); + auto it = var_map_.find(var); + // not populated + if (it == var_map_.end()) { + var_map_.Set(var, rhs); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + Expr mapped_value = (*it).second; + if (CanProveShapeEqual(mapped_value, rhs, analyzer_)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } + auto lhs_shape = lhs.as(); + auto rhs_shape = rhs.as(); + ICHECK(lhs_shape) << "lhs must have a shape"; + if (!rhs_shape) return BaseCheckResult::kFailL2; + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } + + BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) final { + // Set populate mapping to false + // so we do not pick up symbolic vars in params with function type. + // + // @R.function + // def f(g: R.Func([R.Tensor[(n,)]], R.Tensor[(n+1,)]), + // x: R.Tensor[(m,)]) -> R.Tensor[(m,)]: + // ... + // + // For example, in the above function f, we should avoid + // pick up n in g's signature. + bool populate_mapping = false; + std::swap(populate_mapping_, populate_mapping); + auto ret = StructInfoBaseChecker::FuncParamsCheck(lhs, rhs); + std::swap(populate_mapping_, populate_mapping); + return ret; + } +}; + +StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return CallRetStructInfoDeriver(&inst).Derive(finfo, call, ctx); + } else { + return CallRetStructInfoDeriver(ana).Derive(finfo, call, ctx); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") + .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetStructInfo(finfo, call, ctx); + }); + +//-------------------------- +// UnifyToLCA +//-------------------------- +class StructInfoLCAFinder + : public StructInfoFunctor { + public: + explicit StructInfoLCAFinder(arith::Analyzer* ana) : analyzer_(ana) {} + + StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + // quick path + if (lhs.same_as(other)) return lhs; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is based of everything, unify to object. 
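+ // A few illustrative LCA results (sketch, not exhaustive):
+ //   LCA(R.Tensor((n, 4), "float32"), R.Tensor((m, 4), "float32")) -> R.Tensor(ndim=2, "float32")
+ //   LCA(R.Tensor((n, 4), "float32"), R.Tensor((n, 4), "int32"))   -> R.Tensor((n, 4)) with unknown dtype
+ //   LCA(R.Tensor((n, 4), "float32"), R.Shape([n, 4]))             -> R.Object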
+ StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + if (lhs->dtype == rhs->dtype) return GetRef(lhs); + // PrimType will be treated as their boxed(object) values + // as a result we can unify to object. + return ObjectStructInfo(lhs->span); + } + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; + if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || + !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { + // prefers return same when possible + if (!lhs->values.defined() && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return ShapeStructInfo(ndim, lhs->span); + } + } + // equals to each other + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // find the target dtype and ndim. + DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; + // if ndim mismatch or one side of shape is missing + // then we cannot keep in symbolic shape + if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || + !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) { + // reuse lhs when possible + if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return TensorStructInfo(dtype, ndim, lhs->span); + } + } + // symbolic shape match but dtype mismatch + if (lhs->dtype != dtype) { + return TensorStructInfo(lhs->shape.value(), dtype, lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Optional> fields = UnifyArray(lhs->fields, rhs->fields); + // tuple length not the same. + if (!fields.defined()) return ObjectStructInfo(lhs->span); + + // same length tuple. + if (!fields.same_as(lhs->fields)) { + return TupleStructInfo(fields.value(), lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + if (lhs->derive_func.same_as(rhs->derive_func)) { + return GetRef(lhs); + } else { + // Create a new opaque with object return + return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), lhs->span); + } + } else { + // no derivation function, only depends on ret + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + if (ret.same_as(lhs->ret)) return GetRef(lhs); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + } + // rhs is opaque, lhs is not + if (rhs->IsOpaque()) { + // unify ret value, note that rhs's ret is context free(because it is opaque) + // so result of the unify is also context-free. 
+ StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + + // Both lhs and rhs are not opaque + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort of unify types without considering var remap. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), GetRef(rhs))) { + return GetRef(lhs); + } + + auto params = UnifyArray(lhs->params.value(), rhs->params.value()); + auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); + + if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { + return GetRef(lhs); + } else { + // fail to unify the params + if (!params.defined()) { + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } else { + return FuncStructInfo(params.value(), ret, lhs->span); + } + } + } + + private: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // check arrays + Optional> UnifyArray(const Array& lhs, + const Array& rhs) { + if (lhs.same_as(rhs)) return lhs; + if (lhs.size() != rhs.size()) return NullOpt; + size_t index = 0; + return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); + } +}; + +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoLCAFinder(&inst)(lhs, rhs); + } else { + return StructInfoLCAFinder(ana)(lhs, rhs); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") + .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { + return StructInfoLCA(lhs, rhs); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc new file mode 100644 index 000000000000..d7d84c197366 --- /dev/null +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using namespace tir; + +class PatternKindAnalyzer : public StmtExprVisitor { + public: + explicit PatternKindAnalyzer(const tir::PrimFunc& func) { + for (const tir::Var& param : func->params) { + Optional param_buf = func->buffer_map.Get(param); + if (param_buf.defined()) { + param_buffers_.insert(param_buf.value()); + } + } + } + + private: + bool IsOutputBlock(const BlockNode* block) { + for (const BufferRegion& write_region : block->writes) { + if (param_buffers_.count(write_region->buffer)) { + return true; + } + } + return false; + } + + void VisitStmt_(const BufferStoreNode* op) final { + // We only support one buffer store in a block (usually generated by TE compute) + // If we have already seen buffer store in the current block, classify as Opaque. + if (store_.defined() && !IsSameArray(op->indices, store_.value()->indices)) { + kind_ = relay::kOpaque; + return; + } + store_ = GetRef(op); + StmtVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode* op) final { + loads_.push_back(GetRef(op)); + ExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BlockNode* op) final { + if (op->name_hint == "root") { + // Skip the root block + StmtVisitor::VisitStmt(op->body); + return; + } + + // Step 1. Clear loads and store + loads_.clear(); + store_ = NullOpt; + // Step 2. Visit block body. + StmtVisitor::VisitStmt(op->body); + + // We support exactly one buffer store in a block (usually generated by TE compute) + // If we have not seen any store in the current block, classify as Opaque. + if (!store_.defined()) { + kind_ = relay::kOpaque; + return; + } + + BufferStore store = store_.value(); + + // Step 3. Checking load store indices pattern + relay::OpPatternKind index_pair_pattern = relay::kElemWise; + bool has_elem_wise = false; + for (const BufferLoad& load : loads_) { + // Since elemwise is stricter than broadcast and broadcast is stricter than injective, + // while the order amount enums: kElemWise < kBroadcast < kInjective. + // We can simply use `std::max` to detect these three patterns. + // E.g Here is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i] + // Buffer C and A are elemwise but C and B are broadcast. So the whole block follows + // broadcast pattern. + if (IsElemwisePattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise); + has_elem_wise = true; + } else if (IsBroadcastPattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast); + } else if (IsInjectivePattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kInjective); + } else { + index_pair_pattern = relay::kOpaque; + break; + } + } + // If there is a index pair is kElemWise and others are kBroadcast, we regard it as kElemWise + // e.g. A[i, j] = B[i, j] + C[i] + if (index_pair_pattern == relay::kBroadcast && has_elem_wise) { + index_pair_pattern = relay::kElemWise; + } + // If the block index pattern is not opaque, update kind. + if (index_pair_pattern != relay::kOpaque) { + // This rule for softmax: reduce + injective. + if (IsOutputBlock(op) && kind_ == relay::kCommReduce) { + kind_ = relay::kOutEWiseFusable; + } else { + kind_ = std::max(kind_, index_pair_pattern); + } + return; + } + + // Step 4. Checking if the block contains reduce axis by looking into block iterators. 
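+ // E.g. C[i, j] += A[i, k] * B[k, j] has a reduce iter k and matches the FMA
+ // check below, so it becomes kOutEWiseFusable (matmul-like), whereas a plain
+ // reduction such as A[i] = sum(B[i, j]) stays kCommReduce.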
+ bool has_reduction = false; + Array reduce_vars; + for (const IterVar& it : op->iter_vars) { + if (it->iter_type == kCommReduce) { + has_reduction = true; + reduce_vars.push_back(it->var); + } + } + + if (has_reduction) { + if (IsFMA(op->body)) { + // FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv. + kind_ = std::max(kind_, relay::kOutEWiseFusable); + return; + } else { + for (size_t i = 0; i < loads_.size(); ++i) { + // If it's not a pure reduce, regards as kOutEWiseFusable. + // This rule works for pooling for now. + if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) { + kind_ = std::max(kind_, relay::kOutEWiseFusable); + return; + } + } + } + kind_ = std::max(kind_, relay::kCommReduce); + } else { + kind_ = relay::kOpaque; + } + } + + /********** Helper Functions **********/ + + /*! \brief Checking if two arrays contains same elements. */ + static bool IsSameArray(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!lhs[i].same_as(rhs[i])) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices follows elemwise pattern. + * It's elemwise pattern iff load indices and store indices are the same. + * E.g A[i, j] = B[i, j] + */ + static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) { + return IsSameArray(store->indices, load->indices); + } + + /*! + * \brief Checking the load indices and store indices follows broadcast pattern. + * It's broadcast pattern iff all load indices are in the store indices in order + * E.g. A[i, j] = B[i] is broadcast since all load indices(`i`) are in the store indices + * A[i, j] = B[i, k] is not broadcast since `k` are not in the store indices. + * A[i, j] = B[j, i] is not broadcast the load indices are not in the same order as store's + */ + static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) { + size_t ndim_load_buf = load->buffer->shape.size(); + size_t ndim_store_buf = store->buffer->shape.size(); + + for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) { + if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) { + // Skip unit load dimensions + // E.g. A[i, j] = B[1, j] is still broadcast + continue; + } + + // Try to find the i-th load index in the store indices. + while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) { + ++j; + } + + // It's not broadcast if we cannot find load indices in the store indices in order. + if (j == ndim_store_buf) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices follows injective pattern. + * It's injective pattern iff all load index vars are in the store indices, no matter orders. + * Note that we only support store indices are direct vars so far, which can be enhance later. + * E.g. A[i, j] = B[j, i] is injective. + * A[i, j] = B[i - j] is injective since the load index vars are only i, j + */ + static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& store_index : store->indices) { + if (const auto* v = store_index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& load_index : load->indices) { + // return false if there are vars used in load indices but not in store indices. 
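+ // E.g. A[i, j] = B[i - j] is accepted since every load var ({i, j}) is also a
+ // store index, while A[i] = B[i, j] is rejected here because j never appears
+ // in the store indices.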
+ if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices allow data reuse. + * It allow data reuse iff there is any vars in load indices but they are not in store indices + * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse. + * Store = A[i, j] and Load = B[i, j + k] allow data reuse. + */ + static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& index : store->indices) { + if (const auto* v = index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& index : load->indices) { + PreOrderVisit(index, [&](const ObjectRef& node) { + if (const auto* v = node.as()) { + if (vars.count(v)) { + vars.erase(v); + } + } + return true; + }); + } + return !vars.empty(); + } + + /*! \brief Checking if the stmt is multiply add. E.g. C[i, j] += A[i, k] * B[j, k] */ + static bool IsFMA(const Stmt& body) { + if (const auto* store = body.as()) { + if (const auto* add = store->value.as()) { + if (const auto* l = add->a.as()) { + if (const auto* r = add->b.as()) { + bool incremental = + store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices); + const auto* l_load = r->a.as(); + const auto* r_load = r->b.as(); + if (incremental && l_load && r_load) { + return IsAllowReusePattern(GetRef(store), GetRef(l_load)) && + IsAllowReusePattern(GetRef(store), GetRef(r_load)); + } + } + } + } + } + return false; + } + + /*! + * \brief Checking if it is pure reduce pattern. + * It's pure reduce pattern iff all reduces axis are directly reduce var + * E.g. A[i] = sum(B[i, j]) is pure reduce + * A[i] = sum(B[i, j + k]) is not pure reduce + * pooling is not pure reduce + */ + static bool IsPureReducePattern(Array reduce_loops, Array indices) { + for (const PrimExpr& e : indices) { + int id = -1; + if (UsesVar(e, [&](const tir::VarNode* var) { + for (size_t i = 0; i < reduce_loops.size(); ++i) { + if (reduce_loops[i].get() == var) { + id = i; + return true; + } + } + return false; + })) { + if (!reduce_loops[id].same_as(e)) { + return false; + } + } + } + return true; + } + + private: + /*! + * \brief The BufferStore node in the current block. + * \note We only support one BufferStore node in a block (usually generated by TE compute) + */ + Optional store_; + /*! \brief The BufferLoad nodes in the current block. */ + Array loads_; + /*! \brief The result of op pattern. */ + relay::OpPatternKind kind_ = relay::kElemWise; + /*! \brief The buffers from function params. I.e. the input and output buffers. 
*/ + std::unordered_set param_buffers_; + + public: + relay::OpPatternKind GetResult() { return kind_; } +}; + +relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) { + PatternKindAnalyzer analyzer(func); + analyzer(func->body); + return analyzer.GetResult(); +} + +bool HasReshapePattern(const PrimFunc& func) { + class ReshapeDetector : public StmtVisitor { + public: + static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) { + ReshapeDetector detector(src_buffer, dst_buffer); + detector(stmt); + return detector.is_reshape_; + } + + private: + explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer) + : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {} + + void VisitStmt_(const ForNode* loop) final { + ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + if (!(loop->body->IsInstance() || loop->body->IsInstance())) { + return; + } + this->VisitStmt(loop->body); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + // Constructing the mapping from block iterators to iterator + // binding values. The mapping will be used in the substitution of + // the flattened buffer access index. + const Block& block = block_realize->block; + const Array& block_iter = block->iter_vars; + const Array& iter_values = block_realize->iter_values; + ICHECK_EQ(block_iter.size(), iter_values.size()); + int n_iter = block_iter.size(); + for (int i = 0; i < n_iter; ++i) { + // To detect the reshape pattern, we require each block iter to be data-parallel. + if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) { + return; + } + } + + // Recurse into the block. + this->VisitStmt(block); + } + + void VisitStmt_(const BlockNode* block) final { + // Step 0. If the block body is a ForNode, recurse into it. + if (block->body->IsInstance()) { + this->VisitStmt(block->body); + return; + } + + for (const IterVar& v : block->iter_vars) { + ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); + } + + // Step 1. Get the load/store pattern of the block body. + // To detect the reshape pattern, we require the block body to be a + // BufferStore, which has a BufferLoad as value. + const auto* buffer_store = block->body.as(); + if (buffer_store == nullptr) { + return; + } + const auto* buffer_load = buffer_store->value.as(); + if (buffer_load == nullptr) { + return; + } + // Further, we require the buffer being stored and being loaded to + // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`. + if (!(buffer_store->buffer.same_as(dst_buffer_) && + buffer_load->buffer.same_as(src_buffer_))) { + return; + } + + // Step 3. Calculate the flattened access index according to the load/store pattern. + auto f_calc_flattened_idx = [](const Buffer& buffer, const Array& indices) { + ICHECK_EQ(indices.size(), buffer->shape.size()); + int ndim = indices.size(); + PrimExpr idx = 0; + for (int i = 0; i < ndim; ++i) { + idx = idx * buffer->shape[i] + indices[i]; + } + return idx; + }; + PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices); + PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices); + + // Step 4. Check if we can prove the equality of flattened indices. 
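+ // Worked example (sketch): reshaping A of shape (2, 6) into B of shape (3, 4)
+ // is commonly generated as
+ //   B[v0, v1] = A[(v0 * 4 + v1) // 6, (v0 * 4 + v1) % 6]
+ // Flattening both accesses gives
+ //   dst_idx = v0 * 4 + v1
+ //   src_idx = ((v0 * 4 + v1) // 6) * 6 + (v0 * 4 + v1) % 6
+ // which the analyzer, with the iter ranges bound above, can prove equal.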
+ if (ana_.CanProveEqual(src_idx, dst_idx)) { + this->is_reshape_ = true; + } + } + + bool is_reshape_; + const Buffer& src_buffer_; + const Buffer& dst_buffer_; + arith::Analyzer ana_; + }; + + if (func->params.size() < 2) { + return false; + } + Optional src_buffer = func->buffer_map.Get(func->params.front()); + Optional dst_buffer = func->buffer_map.Get(func->params.back()); + if (!(src_buffer.defined() && dst_buffer.defined())) { + return false; + } + + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + ICHECK(func->body->IsInstance()); + return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body); +} + +TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc new file mode 100644 index 000000000000..1c49fd581f7d --- /dev/null +++ b/src/relax/analysis/udchain.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/analysis/udchain.cc + * \brief Implementation of use-def analysis. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class UDChain : public relax::ExprVisitor { + public: + // nullptr users means it is the output of the function. 
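+ // Illustrative result (sketch): for a function body
+ //   lv0 = R.add(x, y)
+ //   gv = R.multiply(lv0, z)
+ //   return gv
+ // the analysis records x -> {lv0}, y -> {lv0}, z -> {gv}, lv0 -> {gv} and
+ // gv -> {nullptr}, where the nullptr entry marks gv as a function output.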
+ std::map> to_users; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); } + void VisitVarDef(const Var& var) override { to_users[var.get()] = {}; } + void VisitExpr_(const FunctionNode* op) override { + cur_user_ = nullptr; + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +std::pair>, runtime::Array> FunctionUseDef( + const Function& fn) { + UDChain udchain; + udchain.VisitExpr(fn); + + Map> user_map; + Array fn_outs; + + for (const auto& [var, users] : udchain.to_users) { + Array uses{}; + uses.reserve(users.size()); + for (const auto& v : users) { + if (v == nullptr && + std::find(fn_outs.begin(), fn_outs.end(), GetRef(var)) == fn_outs.end()) { + fn_outs.push_back(GetRef(var)); + } else { + uses.push_back(GetRef(v)); + } + } + user_map.Set(GetRef(var), std::move(uses)); + } + return std::make_pair(std::move(user_map), std::move(fn_outs)); +} + +runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { + UDChain udchain; + udchain.VisitBindingBlock(dfb); + runtime::Map> ret; + for (const auto& [var, users] : udchain.to_users) { + Array uses{}; + uses.reserve(users.size()); + for (const auto& v : users) uses.push_back(GetRef(v)); + ret.Set(GetRef(var), std::move(uses)); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc new file mode 100644 index 000000000000..be50e9bdcef2 --- /dev/null +++ b/src/relax/analysis/var2value.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { +class Var2ValAnalysis : public relax::ExprVisitor { + public: + tvm::runtime::Map var2value_; + void VisitBinding_(const VarBindingNode* binding) override { + var2value_.Set(binding->var, binding->value); + // Recursively visit the value to handle local functions. 
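+ // E.g. after visiting the binding lv0 = R.add(x, y), var2value_ maps
+ // lv0 -> R.add(x, y); values that are nested Function literals are visited
+ // too, so bindings inside local functions are also recorded.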
+ VisitExpr(binding->value); + } +}; + +tvm::runtime::Map AnalyzeVar2Value(const Expr& expr) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitExpr(expr); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitBindingBlock_(dfb.get()); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const IRModule& m) { + Var2ValAnalysis var2val_analysis; + + for (const auto& it : m->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + var2val_analysis.VisitExpr(GetRef(n)); + } + } + + return std::move(var2val_analysis.var2value_); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { + return AnalyzeVar2Value(f); +}); + +class Name2BindingAnalysis : public relax::ExprVisitor { + public: + // runtime::Map is not suitable for doing in-place update. + // so we use standard container for internal usage. + std::map> name2bindings_; + void VisitBinding_(const VarBindingNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } +}; + +Map> NameToBinding(const Function& fn) { + Name2BindingAnalysis analysis{}; + analysis.VisitExpr_(fn.get()); + return Map>(std::make_move_iterator(analysis.name2bindings_.begin()), + std::make_move_iterator(analysis.name2bindings_.end())); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc new file mode 100644 index 000000000000..3eeefd0be584 --- /dev/null +++ b/src/relax/analysis/well_formed.cc @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/well_formed.cc + * \brief Check if the IRModule is well-formed. + * + * This pass is supposed to be applied to normalized Relax AST. + * If it's malformed, messages will be logged as Warning. + * This pass will check: + * 1. Each Expr should have `struct_info_` field already populated, when + * `check_struct_info` is true. + * 2. GlobalVars are defined before use. + * 3. When a Function has a corresponding GlobalVar and a `global_symbol` + * attribute, the name of the GlobalVar must equal the value of the + * `global_symbol` attribute value. + * 4. Any variable cannot used as different function parameters in the same IRModule + * 5. Vars are defined before use. + * 6. Vars are defined exactly once. 
+ * 7. Symbolic Vars are defined before use. + * 8. DataflowVars cannot be defined inside BindingBlock. + * 9. Vars defined in IfNode, except the return Var, are invisible + * out of the If body.(May change for new AST designs) + * 10. SeqExpr only serves as function body, or in the true and + * false branches in IfNode. + * 11. The IR is in ANF: + * (a) Expressions cannot contain nested complex expressions. + * Here are the expressions that may be nested inside other expressions: + * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, + * Op, Tuple (we call these "leaf" expressions). + * (b) The right-hand side of a binding may contain a non-leaf expression + * (where all expressions nested in it are leaf expressions), + * other than SeqExprs (see rule 6) + * (c) Exceptions: The body of a Function node and the true branch + * and false branch of If nodes *must* be SeqExprs. + * (d) Places where non-leaf expressions cannot appear: + * * The tuple_value field of TupleGetItem nodes + * * The cond field of If nodes + * * The op or args fields of Call nodes + * * Inside the fields of Tuple nodes + * 12. Expr always has checked_type_ (with the exception of Op). + */ +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +// TODO(relax-team): Consider further refactor using +// Scope Frame to store manage the var context. +// +/*! \brief Helper to implement well formed check.*/ +class WellFormedChecker : public relax::ExprVisitor, + public relax::StructInfoVisitor, + public tir::ExprVisitor { + public: + static bool Check(IRModule mod, bool check_struct_info) { + WellFormedChecker well_formed_checker = WellFormedChecker(mod, check_struct_info); + + for (const auto& it : mod->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); + well_formed_checker.VisitExpr(func); + } + } + return well_formed_checker.well_formed_; + } + + private: + explicit WellFormedChecker(IRModule mod, bool check_struct_info) + : mod_(std::move(mod)), check_struct_info_(check_struct_info) {} + + using relax::ExprVisitor::VisitExpr_; + using tir::ExprVisitor::VisitExpr; + using tir::ExprVisitor::VisitExpr_; + + // Possible mode of visitor + enum class VisitMode { + /*! + * \brief Check all vars are well-defined + */ + kDefault, + /*! + * \brief Match define the vars on first occurance. + * Do not check the well-defined property of composite expr. 
+ */ + kMatchVarDef + }; + + void Malformed(Diagnostic diag) { + well_formed_ = false; + LOG(WARNING) << "This IR is not well formed: " << diag->message; + } + + void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { + // check name in global var and gsymbol + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol != var->name_hint) { + Malformed(Diagnostic::Error(func->span) + << "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint + << " != " << gsymbol.value()); + } + } + + void VisitExpr(const Expr& expr) final { + if (!expr.as() && !expr->checked_type_.defined()) { + Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); + } + relax::ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const GlobalVarNode* op) final { + GlobalVar var = GetRef(op); + if (!(mod_->ContainGlobalVar(var->name_hint) && + mod_->GetGlobalVar(var->name_hint).same_as(var))) { + Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " is not defined."); + } + + if (op->checked_type_.defined()) { + if ((!op->checked_type_->IsInstance()) && + (!op->checked_type_->IsInstance())) { + Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " << op->name_hint + << " must be either FuncType or PackedFuncType."); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleNode* op) final { + for (size_t i = 0; i < op->fields.size(); i++) { + Expr expr = op->fields[i]; + if (IsLeafOrTuple(expr)) { + this->VisitExpr(expr); + } else { + Malformed(Diagnostic::Error(expr) + << "Tuple is not in ANF form, field " << i << " gets " << expr->GetTypeKey()); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + if (IsLeafOrTuple(op->tuple)) { + this->VisitExpr(op->tuple); + } else { + Malformed(Diagnostic::Error(op) + << "The tuple value in a TupleGetItem node must be a leaf expression."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const DataflowVarNode* op) final { + DataflowVar var = GetRef(op); + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << op->name_hint() << " is used outside DataflowBlock."); + } + if (dataflow_var_set_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const FunctionNode* op) final { + // save the var_set_ for local function + auto prev_var_set = var_set_; + auto prev_dataflow_var_set = dataflow_var_set_; + auto prev_symbolic_var_set = symbolic_var_set_; + bool old_dataflow_state = is_dataflow_; + // symbolic var is not captured across function boundaries + symbolic_var_set_.clear(); + is_dataflow_ = false; + + // first populate defs in params + WithMode(VisitMode::kMatchVarDef, [&]() { + ICHECK(mode_ == VisitMode::kMatchVarDef); + for (Var param : op->params) { + relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + } + }); + + // check all expr are well defined. 
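+ // E.g. for
+ //   def f(x: R.Tensor((n, 4), "float32")) -> R.Tensor((n * 2,), "float32"): ...
+ // the kMatchVarDef pass above registers the symbolic var n from the annotation
+ // of x, so the use of n in ret_struct_info below is treated as defined.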
+ for (Var param : op->params) { + this->VisitVarDef(param); + + if (param_var_func_map_.count(param) == 1) { + // TODO(relax-team): Complete this error info after we integrate printer + Malformed(Diagnostic::Error(param->span) + << "Relax variable " << param->name_hint() + << " is repeatedly used as parameters in function."); + } + param_var_func_map_.insert({param, GetRef(op)}); + } + // check function ret_struct_info + if (op->ret_struct_info.defined()) { + this->VisitStructInfo(op->ret_struct_info); + } else { + Malformed(Diagnostic::Error(op) << "Function must have defined ret_struct_info"); + } + + if (auto seq = op->body.as()) { + this->VisitSeqExpr(seq); + } else { + Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions"); + } + + is_dataflow_ = old_dataflow_state; + dataflow_var_set_ = prev_dataflow_var_set; + var_set_ = prev_var_set; + symbolic_var_set_ = prev_symbolic_var_set; + } + + void VisitExpr_(const CallNode* op) final { + if (IsLeafOrTuple(op->op)) { + this->VisitExpr(op->op); + } else { + Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); + } + for (size_t i = 0; i < op->args.size(); i++) { + Expr arg = op->args[i]; + if (IsLeafOrTuple(arg)) { + this->VisitExpr(arg); + } else { + Malformed(Diagnostic::Error(arg->span) + << "Call is not in ANF form, arg " << i << " gets " << arg->GetTypeKey()); + } + } + + for (const StructInfo& sinfo_arg : op->sinfo_args) { + this->VisitStructInfo(sinfo_arg); + } + + CheckStructInfo(op); + } + + void VisitExpr_(const IfNode* op) final { + if (IsLeafOrTuple(op->cond)) { + this->VisitExpr(op->cond); + } else { + Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); + } + auto true_seq = op->true_branch.as(); + auto false_seq = op->false_branch.as(); + if (true_seq && false_seq) { + std::unordered_set previous_var_set = var_set_; + std::unordered_set previous_symbolic_var_set = + symbolic_var_set_; + this->VisitSeqExpr(true_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + this->VisitSeqExpr(false_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + } else { + Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs"); + } + CheckStructInfo(op); + } + + void VisitExpr_(const ShapeExprNode* op) final { + for (PrimExpr expr : op->values) { + // check if the symbolic vars in the expr are defined, e.g, 2 * m + tir::ExprVisitor::VisitExpr(expr); + if (!expr.dtype().is_int()) { + Malformed(Diagnostic::Error(expr) + << "Shape expressions must be of integer type, but got " << expr.dtype()); + } + } + CheckStructInfo(op); + } + + void VisitExpr_(const SeqExprNode* op) final { + Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function body in FunctionNode, " + "or the true/false branch body in IfNode."); + } + + void VisitSeqExpr(const SeqExprNode* op) { + // a special call only if SeqExpr is the function body + // in FunctionNode or the true/false branch body in IfNode + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + if (!IsLeafOrTuple(op->body)) { + Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf expressions."); + } + this->VisitExpr(op->body); + CheckStructInfo(op); + } + + void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.insert(binding->var); + } + if (binding->value->IsInstance()) { + 
Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); + } else { + this->VisitExpr(binding->value); + } + this->VisitVarDef(binding->var); + if (is_lambda) { + recur_vars_.erase(binding->var); + } + } + + void VisitBinding_(const MatchCastNode* binding) final { + this->VisitExpr(binding->value); + // define the vars + WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); }); + + this->VisitStructInfo(binding->struct_info); + this->VisitVarDef(binding->var); + } + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + bool old_is_dataflow_ = is_dataflow_; + is_dataflow_ = true; + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + is_dataflow_ = old_is_dataflow_; + dataflow_var_set_.clear(); + } + + void VisitVarDef_(const DataflowVarNode* var) final { + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock."); + } + DataflowVar lv = GetRef(var); + if (dataflow_var_set_.count(lv) == 1) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << lv->name_hint() << " is defined more than once."); + } + // register DataflowVar + dataflow_var_set_.insert(lv); + CheckStructInfo(var); + } + + void VisitVarDef_(const VarNode* var) final { + Var gv = GetRef(var); + if (var_set_.count(gv) == 1) { + Malformed(Diagnostic::Error(var) + << "Var " << gv->name_hint() << " is defined more than once."); + } + // register Var + var_set_.insert(gv); + CheckStructInfo(var); + } + + void VisitVarDef(const Var& var) final { + if (const DataflowVarNode* lv_node = var.as()) { + VisitVarDef_(lv_node); + } else if (const VarNode* gv_node = var.as()) { + VisitVarDef_(gv_node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + } + + void VisitExpr_(const tir::VarNode* op) final { + tir::Var var = GetRef(op); + // default mode, check defined. + if (symbolic_var_set_.count(var) == 0) { + this->Malformed(Diagnostic::Error(var) + << "Symbolic Var " << var->name_hint << " is not defined."); + } + } + + void VisitStructInfo_(const FuncStructInfoNode* op) final { + if (op->params.defined()) { + WithMode(VisitMode::kMatchVarDef, [&]() { + ICHECK(mode_ == VisitMode::kMatchVarDef); + for (StructInfo param : op->params.value()) { + this->VisitStructInfo(param); + } + }); + } + this->VisitStructInfo(op->ret); + } + + void VisitStructInfoExprField(const Expr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (var_set_.count(var) == 0) { + var_set_.insert(var); + } + } + if (auto* shape = expr.as()) { + for (auto val : shape->values) { + this->VisitStructInfoExprField(val); + } + } + } else { + relax::ExprVisitor::VisitExpr(expr); + } + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (symbolic_var_set_.count(var) == 0) { + symbolic_var_set_.insert(var); + } + } + } else { + tir::ExprVisitor::VisitExpr(expr); + } + } + + void CheckStructInfo(const ExprNode* op) { + if (!check_struct_info_) { + return; + } + + auto* sinfo = op->struct_info_.as(); + if (sinfo != nullptr) { + this->VisitStructInfo(GetRef(sinfo)); + } else { + Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. 
" + << " Expr.type_key=" << op->GetTypeKey()); + } + } + + // Run callback with mode. + template + void WithMode(VisitMode mode, FType callback) { + std::swap(mode_, mode); + callback(); + std::swap(mode_, mode); + } + + IRModule mod_; + const bool check_struct_info_; + bool well_formed_ = true; + bool is_dataflow_; + // Current visit mode. + VisitMode mode_ = VisitMode::kDefault; + // set of context variables. + std::unordered_set var_set_; + std::unordered_set recur_vars_; + std::unordered_set dataflow_var_set_; + std::unordered_set symbolic_var_set_; + std::unordered_map param_var_func_map_; +}; + +bool WellFormed(IRModule m, bool check_struct_info) { + return WellFormedChecker::Check(std::move(m), check_struct_info); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")) + .set_body_typed([](IRModule m, bool check_struct_info) { + return WellFormed(m, check_struct_info); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h new file mode 100644 index 000000000000..219799870728 --- /dev/null +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/codegen_json/codegen_json.h + * \brief Utilities for json codegen and runtime + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ +#define TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../../runtime/contrib/json/json_node.h" +#include "../../../../runtime/contrib/json/json_runtime.h" +#include "../../../transform/utils.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace contrib { + +using namespace tvm::runtime::json; + +using ShapeVector = std::vector>; +using TypeVector = std::vector; +using JSONGraphObjectPtr = std::shared_ptr; + +/*! + * \brief Helper class to extract all attributes of a certain op and save them + * into text format. 
+ */ +class OpAttrExtractor : public AttrVisitor { + public: + explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {} + + template ::value>> + std::string Fp2String(const T value) { + std::ostringstream out; + out.precision(std::numeric_limits::max_digits10); + out << value; + return out.str(); + } + + void SetNodeAttr(const char* key, const std::vector& value) { + std::vector attr; + attr.emplace_back(value); + node_->SetAttr(key, attr); + } + + void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); } + + void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); } + + void Visit(const char* key, DataType* value) final { + if (!value->is_void()) { + SetNodeAttr(key, {runtime::DLDataType2String(*value)}); + } else { + SetNodeAttr(key, {""}); + } + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + if (const auto* an = (*value).as()) { + std::vector attr; + for (size_t i = 0; i < an->size(); ++i) { + if (const auto* im = (*an)[i].as()) { + attr.push_back(std::to_string(im->value)); + } else if (const auto* fm = (*an)[i].as()) { + attr.push_back(Fp2String(fm->value)); + } else if (const auto* str = (*an)[i].as()) { + String s = GetRef(str); + attr.push_back(s); + } else { + LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey(); + } + } + SetNodeAttr(key, attr); + } else if (!(*value).defined()) { // Skip NullValue + SetNodeAttr(key, std::vector{""}); + } else if (const auto* im = (*value).as()) { + SetNodeAttr(key, std::vector{std::to_string(im->value)}); + } else if (const auto* fm = (*value).as()) { + SetNodeAttr(key, std::vector{Fp2String(fm->value)}); + } else if (const auto* str = (*value).as()) { + String s = GetRef(str); + SetNodeAttr(key, std::vector{s}); + } else { + LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value; + } + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "NDArray is not allowed in op attribute"; + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "void pointer is not allowed in op attribute"; + } + + void Extract(Object* node) { + if (node) { + reflection_->VisitAttrs(node, this); + } + } + + private: + JSONGraphObjectPtr node_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); +}; + +using NodeEntries = std::vector; + +/*! \brief Serialize a Relax expression to JSON. */ +class JSONSerializer : public relax::MemoizedExprTranslator { + public: + using MemoizedExprTranslator::VisitExpr_; + using MemoizedExprTranslator::VisitBinding_; + + /*! + * \brief Constructor + * \param constant_names The names of all constants in the original module. + */ + explicit JSONSerializer(const Map& constant_names) + : constant_names_(constant_names) {} + + void serialize(Function func) { + // First we convert all the parameters into input nodes. 
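For orientation, `serialize` fills `nodes_` and `heads_`, and `Save` further below writes exactly four top-level fields: `nodes`, `arg_nodes`, `heads`, and `node_row_ptr` (a running sum over per-node output counts). An illustrative Python sketch of that layout; the concrete values are made up and the node payloads abbreviated:

```python
import json

# Hypothetical two-node graph: one "input" node feeding one "kernel" node.
num_outputs = [1, 1]  # GetNumOutput() per node

node_row_ptr = [0]
total = 0
for n in num_outputs:
    total += n
    node_row_ptr.append(total)

graph = {
    "nodes": ["<JSONGraphNode for the input>", "<JSONGraphNode for the kernel>"],
    "arg_nodes": [0],        # indices of leaf (input/constant) nodes
    "heads": [[1, 0, 0]],    # graph outputs as (node id, output index, version)
    "node_row_ptr": node_row_ptr,  # -> [0, 1, 2]
}
print(json.dumps(graph, indent=2))
```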
+ for (const auto& param : func->params) { + auto node_ptr = std::make_shared(param->name_hint(), "input" /* op_type_ */); + memo_[param] = AddNode(node_ptr, param); + } + heads_ = VisitExpr(func->body); + } + + /*!\brief Return the required constants. */ + Array GetConstantNames() const { return constants_used_; } + + /*!\brief Return the generated json. */ + std::string GetJSON() { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + Save(&writer); + return os.str(); + } + + protected: + /*! + * \brief Add a node to graph. + * + * \param node A graph node. It is a shared pointer. Some attributes of it + * will be added, i.e. shape and type. These attributes are attached to + * the JSON graph in the end. + * \param expr The relax expression. + * \return A list of graph entry nodes. It the relax expr is a tuple type, we + * will flatten it. + */ + NodeEntries AddNode(JSONGraphObjectPtr node, const Expr& expr) { + auto struct_info = GetStructInfo(expr); + auto node_id = nodes_.size(); + nodes_.push_back(node); + NodeEntries ret; + ShapeVector shape; + TypeVector dtype; + + // Flatten tuple node. + if (const auto* tuple_sinfo = struct_info.as()) { + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ." + << tuple_sinfo->fields[i]->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + ret.push_back(JSONGraphNodeEntry(node_id, i)); + shape.emplace_back(GetIntShape(output_shape->values)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + } + node->SetNumOutput(tuple_sinfo->fields.size()); + } else { + const auto* tensor_sinfo = struct_info.as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: " + << struct_info->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + + shape.emplace_back(GetIntShape(output_shape->values)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + ret.push_back(JSONGraphNodeEntry(node_id, 0)); + } + std::vector shape_attrs; + shape_attrs.emplace_back(shape); + node->SetAttr("shape", shape_attrs); + + std::vector type_attrs; + type_attrs.emplace_back(dtype); + node->SetAttr("dtype", type_attrs); + return ret; + } + + void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) { + if (cn->op.as()) { + OpAttrExtractor extractor(node); + const Object* call_attr = cn->attrs.get(); + extractor.Extract(const_cast(call_attr)); + } else if (const auto* fn = cn->op.as()) { + ICHECK(false); + auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); + ICHECK(pattern.defined()); + std::vector values; + values.push_back(pattern.value()); + std::vector attr; + attr.emplace_back(values); + node->SetAttr("PartitionedFromPattern", attr); + } + } + + NodeEntries VisitBinding_(const MatchCastNode* binding) { + LOG(FATAL) << "JSON runtime currently doesn't match cast\n"; + return {}; + } + + NodeEntries VisitBinding(const Binding& binding) { + NodeEntries nodes; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << 
binding->GetTypeKey(); + } + return nodes; + } + + NodeEntries VisitBindingBlock(const BindingBlock& block) { + NodeEntries nodes; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return nodes; + } + + NodeEntries VisitBindingBlock_(const BindingBlockNode* block) { + NodeEntries nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + NodeEntries VisitBindingBlock_(const DataflowBlockNode* block) { + NodeEntries nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + NodeEntries VisitExpr_(const SeqExprNode* op) { + NodeEntries nodes; + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + auto from_body = VisitExpr(op->body); + nodes.insert(nodes.end(), from_body.begin(), from_body.end()); + return nodes; + } + + NodeEntries VisitExprDefault_(const Object* op) { + LOG(FATAL) << "JSON runtime currently doesn't support " << op->GetTypeKey(); + return {}; + } + + NodeEntries VisitExpr_(const ConstantNode* cn) { + auto name = constant_names_.find(GetRef(cn)); + ICHECK(name != constant_names_.end()) + << "Cannot find the name of the constant: " << GetRef(cn); + constants_used_.push_back((*name).second); + auto node = std::make_shared((*name).second, "const" /* op_type_ */); + return AddNode(node, GetRef(cn)); + } + + NodeEntries VisitExpr_(const TupleNode* tn) { + NodeEntries fields; + for (const auto& field : tn->fields) { + auto ref = VisitExpr(field); + fields.insert(fields.end(), ref.begin(), ref.end()); + } + return fields; + } + + NodeEntries VisitExpr_(const CallNode* cn) { + Expr expr = GetRef(cn); + std::string name; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else if (const auto* fn = cn->op.as()) { + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; + name = comp.value(); + } else { + LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); + } + + // TODO(@sunggg): Revisit when we have op naming convention. + // Currently, simply remove "relax." prefix to make it work. + name = std::string("tensorrt.") + name.substr(6); + + NodeEntries inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, cn); + return AddNode(node, GetRef(cn)); + } + + NodeEntries VisitExpr_(const TupleGetItemNode* gtn) { + auto vtuple = VisitExpr(gtn->tuple); + return {vtuple[gtn->index]}; + } + + NodeEntries VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + + // FunctionNode should be handled by the caller. + return {}; + } + + /*! 
+ * \brief Save to JSON graph + * + * \param writer A json writer + */ + void Save(dmlc::JSONWriter* writer) { + std::vector arg_nodes; + for (size_t i = 0; i < nodes_.size(); ++i) { + auto node = nodes_[i]; + if (node->IsLeaf()) { + arg_nodes.push_back(i); + } + } + size_t num_entry = 0; + std::vector node_row_ptr{0}; + for (auto node : nodes_) { + num_entry += node->GetNumOutput(); + node_row_ptr.push_back(num_entry); + } + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes_); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads_); + writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); + writer->EndObject(); + } + + private: + /*! \brief JSON graph nodes. */ + std::vector nodes_; + /*! \brief Output of the JSON graph. */ + NodeEntries heads_; + /*! \brief The list of required constants, ordered. */ + Array constants_used_; + /*! \brief The names of all constants in the original module. */ + const Map& constant_names_; +}; + +} // namespace contrib +} // namespace backend +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc new file mode 100644 index 000000000000..8ef68baf6832 --- /dev/null +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cutlass/codegen.cc + * \brief Implementation of the CUTLASS code generator for Relax. 
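The `CUTLASSCompiler` entry point defined later in this file is registered as `relax.ext.cutlass` and expects functions that have already been partitioned into composite form. A hedged sketch of invoking it directly; the shape of `partitioned_functions` and `options` is an assumption following the CUTLASS BYOC flow:

```python
import tvm


def compile_with_cutlass(partitioned_functions, options):
    """Sketch: hand already-partitioned composite Relax functions to the CUTLASS codegen."""
    cutlass_codegen = tvm.get_global_func("relax.ext.cutlass")
    # The third argument (constant names) is unused by this backend, hence the empty map.
    return cutlass_codegen(partitioned_functions, options, {})
```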
+ */ +#include "../../../../relay/backend/contrib/cutlass/codegen.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../../relay/backend/contrib/codegen_c/codegen_c.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using namespace relay::contrib::cutlass; + +using Output = relay::contrib::Output; +using GenerateBodyOutput = relay::contrib::GenerateBodyOutput; +using relay::contrib::cutlass::GenerateBody; +using OutputType = std::vector; + +class CodegenCutlass : public relax::MemoizedExprTranslator, + public relay::contrib::CodegenCBase { + public: + CodegenCutlass(const std::string& id, const Map& bindings) + : ext_func_id_(id), bindings_(bindings) {} + + std::string JIT(const OutputType& out) final { + std::vector arg_types, arg_names; + + for (const auto& arg : ext_func_args_) { + auto sinfo = GetStructInfo(arg); + if (const auto* tensor_sinfo = sinfo.as()) { + arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented"; + } + arg_names.push_back(arg->name_hint()); + } + + code_stream_ << EmitSignature(out, ext_func_id_, arg_names) << "{\n"; + + this->EnterScope(); + + for (auto decl : buf_decl_) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : ext_func_body_) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + this->GenerateBackendCFunc(ext_func_id_, arg_types, /*const_arr_name=*/"", out, true); + return code_stream_.str(); + } + + Array GetHeaders() { return headers_; } + + protected: + OutputType VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + OutputType VisitExpr_(const CallNode* call) final { + const auto* fn_var = call->op.as(); + ICHECK(fn_var); + const auto func = Downcast(bindings_[GetRef(fn_var)]); + const auto pattern_name_opt = func->GetAttr(attr::kComposite); + ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; + auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); + ext_func_body_.push_back(ret.decl); + headers_ = ret.headers; + return ret.outputs; + } + + OutputType VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + // FunctionNode should be handled by the caller. 
+ return {}; + } + + OutputType VisitBinding(const Binding& binding) { + OutputType outputs; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << binding->GetTypeKey(); + } + return outputs; + } + + OutputType VisitBindingBlock(const BindingBlock& block) { + OutputType outputs; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + outputs.insert(outputs.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + outputs.insert(outputs.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << block->GetTypeKey(); + } + return outputs; + } + + OutputType VisitBindingBlock_(const BindingBlockNode* block) { + OutputType outputs; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } + return outputs; + } + + OutputType VisitBindingBlock_(const DataflowBlockNode* block) { + OutputType outputs; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } + return outputs; + } + + OutputType VisitExpr_(const SeqExprNode* op) { + OutputType outputs; + + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + + auto from_body = VisitExpr(op->body); + outputs.insert(outputs.end(), from_body.begin(), from_body.end()); + + return outputs; + } + + private: + Array GetArgumentNames(const CallNode* call) { + Array arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateBody(const CallNode* call, const std::string& func_name, + const Map& attrs) { + auto func_args = GetArgumentNames(call); + auto struct_info = GetStructInfo(GetRef(call)); + + std::vector out_types; + if (const auto* tensor_sinfo = struct_info.as()) { + out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented sinfo type: " << struct_info; + } + + return contrib::GenerateBody(func_name, ext_func_id_, out_types, func_args, attrs, &buf_idx_); + } + + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief The statements of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; + /*! \brief The binding to look up composite functions. */ + Map bindings_; + /*! \brief Required header-file names. 
*/ + Array headers_; +}; + +class CutlassModuleCodegen { + public: + runtime::Module CreateCSourceModule(Array functions, + const Map& options) { + std::string headers = ""; + std::string code = ""; + for (const auto& f : functions) { + auto [f_code, op_headers] = GenCutlassFunc(f, options); + code += "\n" + f_code; + for (const auto& header : op_headers) { + headers += "#include <" + header + ">\n"; + } + } + return Finalize(headers + "\n" + code, func_names_); + } + + private: + std::pair> GenCutlassFunc(const Function& function, + const Map& options) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + auto sid = GetExtSymbol(function); + func_names_.push_back(sid); + + CodegenCutlass builder(sid, AnalyzeVar2Value(function)); + auto out = builder.VisitExpr(function->body); + return {builder.JIT(out), builder.GetHeaders()}; + } + + /*! \brief The accumulated function names. */ + Array func_names_; +}; + +Array CUTLASSCompiler(Array functions, Map options, + Map /*unused*/) { + const auto* tune_func = runtime::Registry::Get("contrib.cutlass.tune_relax_function"); + ICHECK(tune_func != nullptr) + << "The packed function contrib.cutlass.tune_relax_function not found, " + "please import tvm.contrib.cutlass.build"; + + Array annotated_functions = (*tune_func)(functions, options); + + auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); + const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); + ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import " + "tvm.contrib.cutlass.build"; + runtime::Module cutlass_mod = (*pf)(source_mod, options); + + return {cutlass_mod}; +} + +TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc new file mode 100644 index 000000000000..3cbf4cfa2ace --- /dev/null +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/dnnl/codegen.cc + * \brief Implementation of the DNNL JSON serializer. 
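The serializer below hands its JSON to `runtime.DNNLJSONRuntimeCreate`, so DNNL support can be probed from Python before attempting to offload. A small sketch relying only on that function name:

```python
import tvm


def dnnl_runtime_available() -> bool:
    """True if this TVM build ships the DNNL JSON runtime."""
    return tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", allow_missing=True) is not None
```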
+ */ +#include + +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class DNNLJSONSerializer : public JSONSerializer { + public: + DNNLJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const CallNode* root_call = nullptr; + if (composite_name.find("conv2d") != std::string::npos) { + root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); + } else { + LOG(FATAL) << "Unimplemented pattern: " << composite_name; + } + + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + /*! \brief The bindings to look up composite functions. */ + Map bindings_; +}; + +Array DNNLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; + + for (const auto& func : functions) { + DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find DNNL runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.dnnl").set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc new file mode 100644 index 000000000000..5ce6bf5e7d42 --- /dev/null +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/tensorrt/codegen.cc + * \brief Implementation of the TensorRT JSON serializer. + */ +#include +// TODO(sunggg): add operator attribute when it's ready +// #include +#include + +#include +#include +#include + +#include "../../../transform/utils.h" +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +#if TVM_GRAPH_EXECUTOR_TENSORRT +#include "NvInfer.h" +#endif + +namespace tvm { +namespace relax { +namespace contrib { + +/*! \brief Attributes to store the compiler options for TensorRT. */ +struct TensorRTCompilerConfigNode : public tvm::AttrsNode { + Array tensorrt_version; + bool use_implicit_batch; + size_t max_workspace_size; + bool remove_no_mac_subgraphs; + bool use_fp16; + bool use_uint8; + + TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "relax.ext.attrs.TensorRTCompilerConfigNode") { + TVM_ATTR_FIELD(tensorrt_version) + .describe("TensorRT version as (major, minor, patch).") + .set_default(Array({6, 0, 1})); + TVM_ATTR_FIELD(use_implicit_batch).set_default(true); + TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); + TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); + TVM_ATTR_FIELD(use_fp16).set_default(false); + TVM_ATTR_FIELD(use_uint8).set_default(false); + } +}; + +class TensorRTCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, + TensorRTCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.tensorrt.options", TensorRTCompilerConfig); + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using OpAttrExtractor = backend::contrib::OpAttrExtractor; +using JSONSerializer = backend::contrib::JSONSerializer; + +class TensorRTJSONSerializer; + +/*! + * \brief Collect the constants and attributes from all operator calls in the body + * of a "Composite" function. + */ +class CollectFromCompositeFunctionBody : public ExprVisitor { + public: + explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer) + : serializer_(serializer), node_(std::make_shared()) {} + + void VisitExpr_(const ConstantNode* constant_node) final; + void VisitExpr_(const CallNode* call_node) final; + + void SetGenericAttributes(const CallNode* call_node) { + OpAttrExtractor extractor(node_); + const Object* attr_obj = call_node->attrs.get(); + extractor.Extract(const_cast(attr_obj)); + } + + TensorRTJSONSerializer* serializer_; + /*! \brief Accumulated translated arguments. */ + std::vector args_; + /*! + * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the + * final JSONGraphNode however we don't yet know how many inputs that will have. + */ + JSONGraphObjectPtr node_; +}; + +/*! + * \brief Generates an TensorRTModule from a relax expression by serializing the expression to a + * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until + * runtime. 
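The compiler options above travel through the pass-config key `relax.ext.tensorrt.options`, so they are set on a `PassContext` rather than passed explicitly; the availability helpers registered at the bottom of this file can be queried the same way. A hedged sketch (field names come from `TensorRTCompilerConfigNode`; the partitioning and codegen passes themselves are elided):

```python
import tvm

# Helpers registered at the end of this file.
trt_enabled = tvm.get_global_func("relax.is_tensorrt_runtime_enabled")()
trt_version = tvm.get_global_func("relax.get_tensorrt_version")()  # empty if built without TensorRT

with tvm.transform.PassContext(
    config={
        "relax.ext.tensorrt.options": {
            "use_fp16": True,
            "max_workspace_size": 1 << 30,
        }
    }
):
    pass  # run pattern-based partitioning and codegen for the TensorRT composite functions here
```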
+ */ +class TensorRTJSONSerializer : public JSONSerializer { + public: + explicit TensorRTJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + // The call must be to an inline "Composite" function + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + std::string name = opt_composite.value(); + + // Collect the constants and attributes of all operator calls inside the composite body. + CollectFromCompositeFunctionBody collector(this); + collector.VisitExpr(fn->body); + + // Capture the args to the "Composite" function as inputs for this node. + std::vector inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + // Capture constants from the composite function body as additional inputs for this node. + for (const auto& node : collector.args_) { + inputs.emplace_back(node); + } + + // Create the final node. + auto node = std::make_shared(name, + /*op_type=*/"kernel", inputs, + /*num_output=*/1); + + // Transfer attributes from the collector's node to the final node. + node->CaptureAttrs(*collector.node_); + + // Capture global settings on the JSON node. + SaveGlobalAttributes(node); + + VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; + + return AddNode(node, GetRef(call_node)); + } + + static void SaveGlobalAttributes(std::shared_ptr node) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relax.ext.tensorrt.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); + std::vector tensorrt_version = { + std::to_string(cfg.value()->tensorrt_version[0].IntValue()), + std::to_string(cfg.value()->tensorrt_version[1].IntValue()), + std::to_string(cfg.value()->tensorrt_version[2].IntValue())}; + std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; + std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; + std::vector use_fp16 = {std::to_string(cfg.value()->use_fp16)}; + std::vector use_uint8 = {std::to_string(cfg.value()->use_uint8)}; + std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr, + use_fp16_attr, use_uint8_attr; + tensorrt_version_attr.emplace_back(tensorrt_version); + use_implicit_batch_attr.emplace_back(use_implicit_batch); + max_workspace_size_attr.emplace_back(max_workspace_size); + use_fp16_attr.emplace_back(use_fp16); + use_uint8_attr.emplace_back(use_uint8); + node->SetAttr("tensorrt_version", tensorrt_version_attr); + node->SetAttr("use_implicit_batch", use_implicit_batch_attr); + node->SetAttr("max_workspace_size", max_workspace_size_attr); + node->SetAttr("use_fp16", use_fp16_attr); + node->SetAttr("use_uint8", use_uint8_attr); + } + + private: + /*! \brief The bindings to look up composite functions. 
*/ + Map bindings_; +}; + +void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { + for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + args_.emplace_back(entry); + } +} + +void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { + SetGenericAttributes(call_node); + ExprVisitor::VisitExpr_(call_node); +} + +/*! + * \brief Create runtime modules for TensorRT. + * \param functions The extern functions to be compiled via TensorRT + * \return Runtime modules. + */ +Array TensorRTCompiler(Array functions, + Map /*unused*/, + Map constant_names) { + Array compiled_functions; + for (const auto& func : functions) { + VLOG(1) << "TensorRT partition:" << std::endl << func; + TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + std::string graph_json = serializer.GetJSON(); + VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + std::string func_name = GetExtSymbol(func); + VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler); + +/*! + * \brief Check whether TensorRT graph executor is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsTensorRTRuntimeEnabled() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. + */ +Array GetTensorRTVersion() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled); +TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h new file mode 100644 index 000000000000..4190ad66b6df --- /dev/null +++ b/src/relax/backend/contrib/utils.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relax/backend/contrib/utils.h + * \brief Utils function for backend + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ +#define TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ + +#include +#include + +#include +#include + +#include "../../transform/utils.h" + +namespace tvm { +namespace relax { +namespace backend { + +/*! + * \brief Get the Packed Func + * + * \param func_name + * \return const PackedFunc* + */ +inline const PackedFunc* GetPackedFunc(const std::string& func_name) { + return tvm::runtime::Registry::Get(func_name); +} + +/*! + * \brief Extract shape from an IndexExpr array to std::vector + * + * \param shape The shape in Array + * \return The converted shape in std::vector + */ + +inline std::vector GetIntShape(const Array& shape) { + std::vector ret; + for (const auto& dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(pval ? *pval : -1); + } + return ret; +} + +/*! + * \brief Convert type to string + * + * \param typ + * \return std::string string format of type + */ +inline std::string DType2String(const tvm::DataType dtype) { + std::ostringstream os; + if (dtype.is_float()) { + os << "float"; + } else if (dtype.is_int()) { + os << "int"; + } else if (dtype.is_uint()) { + os << "uint"; + } else if (dtype.is_bfloat16()) { + os << "bfloat"; + } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) { + os << "custom[" + << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string() + << "]"; + } else { + LOG(FATAL) << "Unknown type with code " << static_cast(dtype.code()); + } + os << dtype.bits(); + return os.str(); +} + +/*! + * \brief Check if a call node is calling an op with the given name + * \param call The call node whose callee we want to check + * \param op_name The name of the op + * \return true if the callee op matches with the op name + */ +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +/*! + * \brief Return a call node within the function which calls an op with the given name + * The function must contain exactly one call to such op. + * \param f The function to look for an op. + * \param op_name The name of the op + * \return A call node which calls an op with the given name + */ +inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { + auto local_bindings = AnalyzeVar2Value(f); + for (const auto& entry : local_bindings) { + if (auto call = entry.second.as(); call && backend::IsOp(call, op_name)) { + return call; + } + } + LOG(FATAL) << op_name << " not found in the function:\n" << f; + return nullptr; +} + +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc new file mode 100644 index 000000000000..34ebb4d6ddbf --- /dev/null +++ b/src/relax/backend/pattern_registry.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./pattern_registry.h" + +#include "../../support/utils.h" + +namespace tvm { +namespace relax { +namespace backend { +static std::vector* GetRegistryTable() { + static std::vector table; + return &table; +} + +void RegisterPatterns(Array entries) { + auto* table = GetRegistryTable(); + for (const auto& entry : entries) { + table->push_back(entry); + } +} + +void RemovePatterns(Array names) { + std::unordered_set name_set{names.begin(), names.end()}; + + auto* table = GetRegistryTable(); + table->erase( + std::remove_if(table->begin(), table->end(), + [&](const FusionPattern& entry) { return name_set.count(entry->name) > 0; }), + table->end()); +} + +Array GetPatternsWithPrefix(const String& prefix) { + auto* table = GetRegistryTable(); + Array result; + for (auto it = table->rbegin(); it != table->rend(); ++it) { + if (support::StartsWith((*it)->name, prefix.data())) { + result.push_back(*it); + } + } + return result; +} + +Optional GetPattern(const String& pattern_name) { + auto* table = GetRegistryTable(); + for (auto it = table->rbegin(); it != table->rend(); ++it) { + if ((*it)->name == pattern_name) { + return *it; + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); +TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); +TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix); +TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h new file mode 100644 index 000000000000..72eea1238d38 --- /dev/null +++ b/src/relax/backend/pattern_registry.h @@ -0,0 +1,73 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/pattern_registry.h + * \brief Functions related to registering and retrieving patterns for + * functions handled by backends. + */ +#ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ +#define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace backend { + +using transform::FusionPattern; + +/*! 
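Backends query this registry through the globals registered just above. A minimal lookup sketch; the pattern name `cutlass.dense` only illustrates the naming scheme and is not registered by this patch:

```python
import tvm

get_patterns_with_prefix = tvm.get_global_func("relax.backend.GetPatternsWithPrefix")
get_pattern = tvm.get_global_func("relax.backend.GetPattern")

# Later registrations win, matching the reverse iteration in GetPatternsWithPrefix above.
cutlass_patterns = get_patterns_with_prefix("cutlass")
dense_pattern = get_pattern("cutlass.dense")  # None if no such pattern is registered
```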
+ * \brief Register patterns which will be used to partition the DataflowBlock + * into subgraphs that are supported by external backends. + * \param patterns Patterns to be registered. Patterns that appear later in the list have + * higher priority when partitioning DataflowBlock. + */ +void RegisterPatterns(Array patterns); + +/*! + * \brief Remove patterns from the registry by their name. + * \param names The name of patterns to be removed + */ +void RemovePatterns(Array names); + +/*! + * \brief Find patterns whose name starts with a particular prefix. + * \param prefx The pattern name prefix. + * \return Matched patterns, ordered by priority from high to low. + */ +Array GetPatternsWithPrefix(const String& prefix); + +/*! + * \brief Find the pattern with a particular name. + * \param name The pattern name. + * \return The matched pattern. NullOpt if not found. + */ +Optional GetPattern(const String& name); + +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc new file mode 100644 index 000000000000..5bd764c68e78 --- /dev/null +++ b/src/relax/backend/task_extraction.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace backend { + +using tvm::meta_schedule::ExtractedTask; + +/*! + * \brief Extract the Meta-Schedule tuning task from a given IRModule. + * \note + * 1. The task extractor is responsible for task deduplication. The + * deduplication is achieved by comparing structural hashes of PrimFuncs. + * 2. For a PrimFunc, the weight of its corresponding task is the number + * of times it called by op Call-TIR. Say in an IRModule there are three + * PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash. + * Suppose `fn1` is called by 5 Call-TIR ops among all Relax function, + * `fn2` is called by 3 Call-TIR and `fn3` is called by 5 Call-TIR. + * Then we will have a ExtractedTask for all three functions, whose weight + * is 5 + 3 + 2 = 10. + */ +class TaskExtractor : public ExprVisitor { + public: + static Array ExtractTask(IRModule mod, Target target) { + TaskExtractor extractor(mod, target); + // We go through each Relax function in the module. 
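The extractor below is reached through the `relax.backend.MetaScheduleExtractTask` global registered at the end of this file; each `ExtractedTask` carries a deduplicated PrimFunc module and its accumulated weight. A hedged usage sketch (a `tvm.meta_schedule` wrapper around this global is assumed but not required):

```python
import tvm


def extract_relax_tasks(mod: tvm.IRModule, target: str):
    """Sketch: collect Meta-Schedule tuning tasks from a Relax IRModule."""
    extract = tvm.get_global_func("relax.backend.MetaScheduleExtractTask")
    tasks = extract(mod, tvm.target.Target(target))
    for task in tasks:
        print(task.task_name, task.weight)
    return tasks
```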
+ for (const auto& kv : mod->functions) { + if (const auto* func = kv.second.as()) { + extractor(GetRef(func)); + } + } + return std::move(extractor.tasks_); + } + + private: + explicit TaskExtractor(IRModule mod, Target target) + : mod_(std::move(mod)), target_(std::move(target)) { + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // this logic should be changed accordingly. + if (!call->op.same_as(call_tir_op)) { + // Since the Relax function is of A-normal form, the arguments of this call cannot be another + // Calls. And hence we do not need to recurse into this Call. + return; + } + + const GlobalVar& global_var = Downcast(call->args[0]); + const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); + + auto it = func2task_.find(func); + if (it != func2task_.end()) { + it->second->weight += 1; + return; + } + + IRModule tir_mod = (*normalize_mod_func_)(func); + ExtractedTask task(/*task_name=*/global_var->name_hint, // + /*mod=*/tir_mod, // + /*target=*/target_, // + /*dispatched=*/{tir_mod}, // + /*weight=*/1); + tasks_.push_back(task); + func2task_.emplace(func, task); + } + + IRModule mod_; + Target target_; + Array tasks_; + std::unordered_map func2task_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") + .set_body_typed([](IRModule mod, Target target) { + return TaskExtractor::ExtractTask(std::move(mod), std::move(target)); + }); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc new file mode 100644 index 000000000000..b36b5ed4d6c6 --- /dev/null +++ b/src/relax/backend/vm/codegen_vm.cc @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_vm.cc + * \brief A codegen to generate VM executable from a Relax IRModule. + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../target/metadata_module.h" +#include "../../../target/source/codegen_source_base.h" + +namespace tvm { +namespace relax { +namespace relax_vm { + +using tvm::Target; +using namespace relax; +using namespace tvm::runtime; +using namespace tvm::runtime::relax_vm; + +// Helper function to get the function name of the registered packed function implementation of +// relax operator. 
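The code generator in this file is normally driven by the Relax build flow rather than called directly. A hedged sketch of compiling a module and inspecting the emitted VM functions; `tvm.relax.build` and the `Executable.as_text` helper are assumed to be the Python entry points wired to the `relax.VMCodeGen`/`relax.VMLink` globals registered further below:

```python
import tvm
from tvm import relax


def build_and_dump(mod: tvm.IRModule, target: str = "llvm"):
    """Sketch: lower a Relax module to VM bytecode and print a readable listing."""
    ex = relax.build(mod, target)
    print(ex.as_text())  # assumed helper: textual view of the generated VM functions
    return ex
```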
+FCallPacked GetPackedFuncName(const Call& call) { + static auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} + +/*! + * \brief A class to generate VM executable for Relax functions. + */ +class CodeGenVM : public ExprFunctor { + public: + explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod) + : builder_(builder), ctx_mod_(ctx_mod) {} + + static IRModule Run(relax::ExecBuilder builder, IRModule mod) { + IRModule res_mod = IRModule(Map()); + CodeGenVM codegen(builder, mod); + // Remove relax function and turn into TIR func. + for (auto& p : mod->functions) { + if (auto* func = p.second.as()) { + codegen.Codegen(GetRef(func)); + } else { + res_mod->Add(p.first, p.second); + } + } + return res_mod; + } + + protected: + size_t NewRegister() { return registers_num_++; } + + // Convert Arg value to a register, trigger copy if needed + Instruction::Arg EnsureReg(Instruction::Arg arg) { + if (arg.kind() == Instruction::ArgKind::kRegister) { + return arg; + } else { + RegName dst_reg = NewRegister(); + builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg); + return Instruction::Arg::Register(dst_reg); + } + } + + void Codegen(const Function& func) { + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + + Array param_names; + for (Var param : func->params) { + param_names.push_back(param->name_hint()); + } + + builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names); + + for (size_t i = 0; i < func->params.size(); ++i) { + RegName r = NewRegister(); + ICHECK_EQ(r, static_cast(i)); + this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)}); + } + Instruction::Arg ret = ExprFunctor::VisitExpr(func->body); + builder_->EmitRet(EnsureReg(ret)); + builder_->EndFunction(gsymbol.value()); + // reset register number to be 0; + registers_num_ = 0; + var_arg_map_.clear(); + } + + Instruction::Arg VisitExpr_(const SeqExprNode* op) final { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + Instruction::Arg value; + if (auto* var_binding = binding.as()) { + value = this->VisitExpr(var_binding->value); + } else if (auto* match_cast = binding.as()) { + value = this->VisitExpr(match_cast->value); + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + this->var_arg_map_.insert({binding->var, value}); + } + } + + Instruction::Arg ret_reg = this->VisitExpr(op->body); + return ret_reg; + } + + Instruction::Arg VisitExpr_(const CallNode* call_node) final { + Call call = GetRef(call_node); + + if (call_node->op == null_value_op_) { + return Instruction::Arg::Register(Instruction::kVoidRegister); + } + + // allocate dst register. + RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister(); + if (call->op.as()) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + FCallPacked name = GetPackedFuncName(call); + if (!name.empty()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. 
+ EmitPackedFuncCall(call, name, dst_reg); + } else if (call_node->op == call_builtin_with_ctx_op_) { + // TODO(relax-team) migrate most handling of op to + // directly map to call_builtin_with_ctx before codegen and simplify vm codegen. + EmitCallBuiltinWithCtx(call, dst_reg); + } else if (call_node->op == alloc_storage_op_) { + EmitAllocStorage(call, dst_reg); + } else if (call_node->op == alloc_tensor_op_) { + EmitAllocTensor(call, dst_reg); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op; + } + } else { + EmitNormalCall(call, dst_reg); + } + return Instruction::Arg::Register(dst_reg); + } + + Instruction::Arg VisitExpr_(const IfNode* op) final { + const If& ife = GetRef(op); + Instruction::Arg cond_value = this->VisitExpr(ife->cond); + + // Reserve a register for cond + RegName cond_reg = NewRegister(); + builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg); + + // obtain the temp exec in progress. + vm::Executable* exec = builder_->exec(); + + // Record the offset of If instruction + size_t if_offset = exec->instr_offset.size(); + + builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3); + size_t num_instr = exec->instr_offset.size(); + Instruction::Arg true_value = this->VisitExpr(ife->true_branch); + // Reserve a register for return + size_t merge_register = NewRegister(); + // Copy the output from true branch to merge register + builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register); + + // Record the offset of Goto instruction + size_t goto_offset = exec->instr_offset.size(); + + builder_->EmitGoto(1); + + // Calculate the false offset of If + size_t false_offset = exec->instr_offset.size() - num_instr + 1; + + Instruction::Arg false_value = this->VisitExpr(ife->false_branch); + // Copy the output data of false branch to merge register + builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register); + + // Update the offsets of the If instruction emitted above + // Jump to the behind of the next goto instruction + exec->SetInstructionData(if_offset, 2, static_cast(false_offset)); + // Update the pc_offset of Goto instruction + // Jump over the false branch + size_t pc_offset = exec->instr_offset.size() - goto_offset; + exec->SetInstructionData(goto_offset, 1, static_cast(pc_offset)); + return Instruction::Arg::Register(merge_register); + } + + Instruction::Arg VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = this->var_arg_map_.find(var); + ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; + return it->second; + } + + Instruction::Arg VisitExpr_(const ConstantNode* op) final { + return builder_->ConvertConstant(op->data); + } + + Instruction::Arg VisitExpr_(const ShapeExprNode* op) final { + std::vector shape; + for (PrimExpr e : op->values) { + if (auto* int_value = e.as()) { + shape.push_back(int_value->value); + } else { + LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + } + } + return builder_->ConvertConstant(ShapeTuple(shape)); + } + + Instruction::Arg VisitExpr_(const PrimValueNode* op) final { + if (auto* int_imm = op->value.as()) { + return builder_->ConvertConstant(int_imm->value); + } else { + auto* float_imm = op->value.as(); + ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now"; + return builder_->ConvertConstant(float_imm->value); 
+ } + } + + Instruction::Arg VisitExpr_(const StringImmNode* op) final { + return builder_->ConvertConstant(op->value); + } + + Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final { + return builder_->ConvertConstant(op->value); + } + + Instruction::Arg VisitExpr_(const TupleNode* op) final { + Tuple tuple = GetRef(op); + std::vector args; + for (Expr arg : tuple->fields) { + args.push_back(this->VisitExpr(arg)); + } + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.make_tuple", args, dst_register); + + return Instruction::Arg::Register(dst_register); + } + + Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = GetRef(op); + std::vector args = {this->VisitExpr(expr->tuple)}; + + args.push_back(builder_->ConvertConstant(expr->index)); + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); + + return Instruction::Arg::Register(dst_register); + } + + Instruction::Arg VisitExpr_(const GlobalVarNode* op) final { + GlobalVar gvar = GetRef(op); + Optional symbol; + VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc; + + // Run a look up in the env to see if it maps to an extern func. + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* efunc = func.as()) { + symbol = efunc->global_symbol; + kind = VMFuncInfo::FuncKind::kPackedFunc; + } else if (func.as()) { + symbol = gvar->name_hint; + kind = VMFuncInfo::FuncKind::kVMFunc; + } + } + // GlobalVar can be reference to a Relax function or a TIR primfunc + // At this point: all global var must corresponds to the right symbol. + // TODO(relax-team): switch everything to extern before splitting TIR/relax + // so we do not have idle global var here. + if (!symbol.defined()) { + symbol = gvar->name_hint; + kind = VMFuncInfo::FuncKind::kPackedFunc; + } + // declare the function to be safe. 
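Symbols declared as `VMFuncInfo::FuncKind::kPackedFunc` here are looked up by name when the VM runs, so an implementation can typically also be supplied from Python through the global function registry. A small sketch; the name below is a placeholder, not something this patch registers:

```python
import tvm


@tvm.register_func("my_project.custom_kernel")
def _custom_kernel(x):
    # Placeholder body; a real kernel would fill or return NDArrays.
    return x
```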
+ ICHECK(symbol.defined()); + builder_->DeclareFunction(symbol.value(), kind); + return builder_->GetFunction(symbol.value()); + } + + Instruction::Arg VisitExpr_(const ExternFuncNode* op) final { + static const constexpr char* kCSource = "c_source"; + static const constexpr char* kCSourceFmt = "c_source_fmt"; + if (Optional opt_code = op->attrs.GetAttr(kCSource)) { + String sym = op->global_symbol; + String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); + String code = opt_code.value(); + Module c_source_module = + codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, + /*const_vars=*/{}); + builder_->exec()->Import(c_source_module); + } + builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); + return builder_->GetFunction(op->global_symbol); + } + + void EmitAllocStorage(const Call& call_node, RegName dst_reg) { + ICHECK_EQ(call_node->args.size(), 3); + // Handle args of the call + std::vector args; + args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); + // buffer size, dtype, device index + for (auto arg : call_node->args) { + args.push_back(this->VisitExpr(arg)); + } + builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg); + } + + void EmitAllocTensor(const Call& call_node, RegName dst_reg) { + ICHECK_EQ(call_node->args.size(), 4); + std::vector args; + args.reserve(4); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg)); + } + builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg); + } + + void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) { + std::vector args; + args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); + + auto func = this->VisitExpr(call_node->args[0]); + auto tuple_arg = Downcast(call_node->args[1]); + + // Handle args of the call + for (Expr arg : tuple_arg->fields) { + args.push_back(this->VisitExpr(arg)); + } + + builder_->EmitCall(func, args, dst_reg); + } + + void EmitNormalCall(const Call& call_node, RegName dst_reg) { + Instruction::Arg func = VisitExpr(call_node->op); + std::vector args = VisitArray(call_node->args); + builder_->EmitCall(func, args, dst_reg); + } + + // Emits call to packed function `name` with arguments copied over from `call_node` args + void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) { + std::vector args = VisitArray(call_node->args); + builder_->EmitCall(name, args, dst_reg); + } + + std::vector VisitArray(const Array& arr) { + std::vector ret; + for (size_t i = 0; i < arr.size(); ++i) { + ret.push_back(this->VisitExpr(arr[i])); + } + return ret; + } + + /*! \brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! + * \brief Total number of virtual registers allocated. + * \note The first two registers are reserved for special registers. + */ + size_t registers_num_ = 0; + /*! \brief Map from var to register number. */ + std::unordered_map var_arg_map_; + /*! \brief the context module. */ + IRModule ctx_mod_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); +}; + +/*! + * \brief Create the Relax VM executable from all relax.Function in mod. + * and add them to exec_builder. + * \param exec_builder Builder to collect executables. 
+ * \param mod Input module.
+ * \return Leftover IRModule that may contain other functions.
+ */
+IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
+  return CodeGenVM::Run(exec_builder, mod);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
+
+/*!
+ * \brief Link the libraries together.
+ */
+Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, Array<Module> ext_libs,
+              Map<String, runtime::NDArray> params) {
+  // TODO(relax-team) Revisit the param and ext_lib options.
+  ObjectPtr<Executable> executable = builder->Get();
+  if (!lib.defined()) {
+    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
+  }
+  std::unordered_map<std::string, runtime::NDArray> conv_params;
+  for (const auto& [name, param] : params) {
+    conv_params[name] = param;
+  }
+  Module combined_lib = codegen::CreateMetadataModule(
+      conv_params, lib.value(), ext_libs, target,
+
+      // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
+      // Before jumping into details, only support cpp runtime for now.
+      relay::Runtime::Create("cpp"),
+      relay::Executor::Create("graph"),  // TODO(@sunggg): pass arbitrary executor. CPP runtime
+                                         // won't use this anyway.
+      relay::backend::ExecutorCodegenMetadata());
+  executable->Import(combined_lib);
+  return Module(executable);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink);
+
+}  // namespace relax_vm
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc
new file mode 100644
index 000000000000..2f63a50d370f
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/vm/codegen_vm_tir.cc
+ * \brief A codegen that generates VM TIR functions (which can be compiled) from the executable.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+namespace relax_vm {
+
+using vm::VMFuncInfo;
+
+/*!
+ * \brief A class to generate VMTIR for Relax functions.
+ *
+ * \note Skip CallPacked with special attrs for now, as they can be
+ *       further simplified with PrimValue.
+ */
+class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
+ public:
+  explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
+      : builder_(builder), ctx_mod_(ctx_mod) {}
+
+  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
+    // create a new copy
+    IRModule res_mod = mod;
+    res_mod.CopyOnWrite();
+
+    CodeGenVMTIR codegen(builder, mod);
+    // Remove relax function and turn into TIR func.
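+    // Only relax::Function entries are converted and removed; PrimFuncs and other
+    // functions remain in res_mod alongside the generated __vmtir__ functions.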
+ for (auto& p : mod->functions) { + if (auto* func = p.second.as()) { + auto tir_func = codegen.Codegen(GetRef(func)); + auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); + res_mod->Add(GlobalVar(gsymbol.value()), tir_func); + res_mod->Remove(p.first); + } + } + return res_mod; + } + + private: + int64_t NewRegister() { return registers_num_++; } + + static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); } + + static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); } + + PrimExpr RegListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {reg_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr ConstListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {const_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr FuncListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {func_anylist_handle_, ConstInt32(slot)}); + } + + void EmitStmt(tir::Stmt stmt) { + ICHECK(!stmt_stack_.empty()); + stmt_stack_.back().emplace_back(stmt); + } + + void EmitCallPacked(String name, const Array& args, int64_t dst_anylist_slot = -1) { + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(name)); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args))); + } + } + + void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array& args, + int64_t dst_anylist_slot = -1) { + Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase"; + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(gsymbol.value())); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + // push an empty handle to be compatible with current cpacked convention + // TODO(tqchen): revisit C Packed convention + all_args.push_back(tir::make_zero(DataType::Handle())); + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args))); + } + } + + tir::PrimFunc Codegen(const Function& func) { + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. 
" + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + // initialize the state + stmt_stack_ = {}; + registers_num_ = 0; + var_map_.clear(); + ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle()); + reg_anylist_handle_ = tir::Var("r", DataType::Handle()); + func_anylist_handle_ = tir::Var("f", DataType::Handle()); + const_anylist_handle_ = tir::Var("c", DataType::Handle()); + + Array param_names; + for (Var param : func->params) { + param_names.push_back(param->name_hint()); + } + // declare this function. + builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc); + + for (size_t i = 0; i < func->params.size(); ++i) { + int64_t r = NewRegister(); + ICHECK_EQ(static_cast(r), i); + this->var_map_.insert({func->params[i], RegListGet(r)}); + } + size_t ret_reg = NewRegister(); + + tir::Stmt body = WithNewScope([&]() { + Optional ret = ExprFunctor::VisitExpr(func->body); + if (ret.defined()) { + this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); + } + }); + + // Mark the function entry internally. + builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names, + VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_); + builder_->EndFunction(gsymbol.value()); + + Type ret_type = VoidType(); + Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; + String tir_func_name = "__vmtir__" + gsymbol.value(); + tir::PrimFunc tir_func(tir_params, body, ret_type, {}); + tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); + registers_num_ = 0; + var_map_.clear(); + stmt_stack_.clear(); + return tir_func; + } + + Optional VisitExpr_(const SeqExprNode* op) final { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + Optional value; + if (auto* var_binding = binding.as()) { + value = this->VisitExpr(var_binding->value); + } else if (auto* match_cast = binding.as()) { + value = this->VisitExpr(match_cast->value); + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + this->var_map_.insert({binding->var, value}); + } + } + return this->VisitExpr(op->body); + } + + Optional VisitExpr_(const CallNode* call_node) final { + Call call = GetRef(call_node); + + if (call_node->op == null_value_op_) { + return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), + {IntImm(DataType::Int(64), 0)}); + } + int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); + if (call->op.as()) { + if (call_node->op == call_builtin_with_ctx_op_) { + EmitCallBuiltinWithCtx(call, dst_reg); + } else if (call_node->op == alloc_storage_op_) { + EmitAllocStorage(call, dst_reg); + } else if (call_node->op == alloc_tensor_op_) { + EmitAllocTensor(call, dst_reg); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op; + } + } else { + EmitNormalCall(call, dst_reg); + } + if (dst_reg >= 0) { + return RegListGet(dst_reg); + } else { + return NullOpt; + } + } + + Optional VisitExpr_(const IfNode* op) final { + // Reserve a register for return + size_t merge_register = NewRegister(); + PrimExpr cond_value = this->VisitExpr(op->cond).value(); + + // turn ndarray cond value into scalar. 
+ cond_value = tir::Cast(DataType::Bool(), + tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + + tir::Stmt true_branch = WithNewScope([&]() { + PrimExpr true_value = this->VisitExpr(op->true_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register); + }); + tir::Stmt false_branch = WithNewScope([&]() { + PrimExpr false_value = this->VisitExpr(op->false_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register); + }); + this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch)); + return RegListGet(merge_register); + } + + Optional VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = this->var_map_.find(var); + ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; + return it->second; + } + + Optional VisitExpr_(const ConstantNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->data).value()); + } + + Optional VisitExpr_(const ShapeExprNode* op) final { + std::vector shape; + for (PrimExpr e : op->values) { + if (auto* int_value = e.as()) { + shape.push_back(int_value->value); + } else { + LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + } + } + return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value()); + } + + Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } + + Optional VisitExpr_(const StringImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const DataTypeImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const TupleNode* op) final { + Tuple tuple = GetRef(op); + Array args; + for (auto arg : tuple->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + int32_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register); + return RegListGet(dst_register); + } + + Optional VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = GetRef(op); + Array args = {this->VisitExpr(expr->tuple).value()}; + + args.push_back(ConstInt64(expr->index)); + + int64_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); + return RegListGet(dst_register); + } + + // Lookup the function and see if it matches + Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { + if (auto* ext_func = expr.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return ext_func->global_symbol; + } else if (auto* gvar_ptr = expr.as()) { + GlobalVar gvar = GetRef(gvar_ptr); + // Run a look up in the env to see if it maps to an extern func. + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* efunc = func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return efunc->global_symbol; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kVMTIRFunc; + return gvar->name_hint; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } + } + LOG(WARNING) << "Undefined global var " << gvar->name_hint; + // undefined global var, consider eliminate later. 
+ *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + return NullOpt; + } + } + // Lookup PrimFunc in the same module + // We can do direct PrimFunc call in such cases + Optional LookupPrimFunc(const String& name) { + if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt; + + GlobalVar gvar = ctx_mod_->GetGlobalVar(name); + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* prim_func = func.as()) { + return GetRef(prim_func); + } + } + return NullOpt; + } + + Optional VisitExpr_(const GlobalVarNode* op) final { + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(GetRef(op), &kind); + ICHECK(symbol.defined()); + builder_->DeclareFunction(symbol.value(), kind); + return FuncListGet(builder_->GetFunction(symbol.value()).value()); + } + + Optional VisitExpr_(const ExternFuncNode* op) final { + builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); + return FuncListGet(builder_->GetFunction(op->global_symbol).value()); + } + + void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { + // Handle args of the call + Array args; + args.push_back(ctx_ptr_); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg); + } + + void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { + ICHECK_EQ(call_node->args.size(), 4); + Array args; + args.reserve(4); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg); + } + + void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { + Array args; + // if context is required, pass as first argument. + args.push_back(ctx_ptr_); + auto* func = call_node->args[0].as(); + ICHECK(func) << "CallBuiltin comes with extern func"; + + auto tuple_arg = Downcast(call_node->args[1]); + + // Handle args of the call + for (Expr arg : tuple_arg->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + + this->EmitCallPacked(func->global_symbol, args, dst_reg); + } + + void EmitNormalCall(const Call& call_node, int64_t dst_reg) { + Array args = VisitArray(call_node->args); + // A function can be a closure that comes from parent + // Do call closure to be safe. + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(call_node->op, &kind); + + if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) { + // primfunc in the same module. + // use cpacked to directly invoke without named based lookup + if (Optional prim_func = LookupPrimFunc(symbol.value())) { + this->EmitCallCPacked(prim_func.value(), args, dst_reg); + } else { + this->EmitCallPacked(symbol.value(), args, dst_reg); + } + } else { + // Default path, leverage function table and invoke as closure + Array all_args; + all_args.push_back(ctx_ptr_); + all_args.push_back(this->VisitExpr(call_node->op).value()); + for (auto arg : args) { + all_args.push_back(arg); + } + this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg); + } + } + + template + tir::Stmt WithNewScope(const FLambda& callback) { + stmt_stack_.push_back({}); + callback(); + tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back()); + stmt_stack_.pop_back(); + return stmt; + } + + Array VisitArray(const Array& arr) { + Array ret; + for (size_t i = 0; i < arr.size(); ++i) { + ret.push_back(this->VisitExpr(arr[i]).value()); + } + return ret; + } + /*! 
\brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! \brief List to ctx_ptr */ + tir::Var ctx_ptr_; + /*! \brief List to store temp object registers */ + tir::Var reg_anylist_handle_; + /*! \brief List to store closures */ + tir::Var func_anylist_handle_; + /*! \brief List to store constants */ + tir::Var const_anylist_handle_; + /*! + * \brief Total number of virtual registers allocated. + * \note The first two registers are reserved for special registers. + */ + int64_t registers_num_ = 0; + /*! \brief Stack to build up statements */ + std::vector> stmt_stack_; + /*! \brief Map from var to Expr. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> var_map_; + /*! \brief the context module. */ + IRModule ctx_mod_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); +}; + +/*! + * \brief Create the Relax VM executable from all relax.Function in mod. + * and add them to exec_builder. Create extra TIR functions. + * + * \param exec_builder Builder to collect executables. + * \param mod Input module. + * \return Extra TIR module created. + */ +IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { + return CodeGenVMTIR::Run(exec_builder, mod); +} + +TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc new file mode 100644 index 000000000000..b5d932137be0 --- /dev/null +++ b/src/relax/backend/vm/exec_builder.cc @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/exec_builder.cc + */ +#include + +#include + +namespace tvm { +namespace relax { + +using namespace vm; + +TVM_REGISTER_NODE_TYPE(ExecBuilderNode); + +ExecBuilder ExecBuilderNode::Create() { + ExecBuilder ret(make_object()); + ret->exec_ = make_object(); + return ret; +} + +Executable* ExecBuilderNode::exec() const { return exec_.get(); } + +ObjectPtr ExecBuilderNode::Get() { + this->Formalize(); + this->CheckExecutable(); + return exec_; +} + +vm::Instruction::Arg ExecBuilderNode::ConvertConstant_(TVMRetValue cvalue) { + // emit constant immediate as immediate. 
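+  // Integers that fit in the instruction's immediate range are encoded inline;
+  // everything else (strings are first converted to String objects) is placed in
+  // the constant pool, with dedup by structural equality for object constants below.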
+ if (cvalue.type_code() == kDLInt) { + int64_t val = cvalue.operator int64_t(); + if (val <= vm::Instruction::kValueMaxLimit && val >= vm::Instruction::kValueMinLimit) { + return vm::Instruction::Arg::Immediate(val); + } + } + // convert string to object string + if (cvalue.type_code() == kTVMStr) { + cvalue = cvalue.operator String(); + } + + // run dedup for object with structural equality + if (cvalue.IsObjectRef()) { + ObjectRef obj = cvalue.operator ObjectRef(); + auto it = const_dedup_map_.find(obj); + if (it != const_dedup_map_.end()) { + return vm::Instruction::Arg::ConstIdx(it->second); + } + vm::Index idx = exec_->constants.size(); + exec_->constants.push_back(cvalue); + const_dedup_map_[obj] = idx; + return vm::Instruction::Arg::ConstIdx(idx); + } else { + // emit normal constant + vm::Index idx = exec_->constants.size(); + exec_->constants.push_back(cvalue); + return vm::Instruction::Arg::ConstIdx(idx); + } +} + +void ExecBuilderNode::DeclareFunction(const std::string& func_name, VMFuncInfo::FuncKind kind) { + auto it = exec_->func_map.find(func_name); + if (it != exec_->func_map.end()) { + ICHECK(kind == exec_->func_table[it->second].kind) + << "Function " << func_name << "already declared in a different kind"; + return; + } + VMFuncInfo vmfunc; + vmfunc.kind = kind; + vmfunc.name = func_name; + // use num args to mark undefined. + vmfunc.start_instr = 0; + vmfunc.num_args = -2; + vmfunc.register_file_size = 0; + exec_->func_map[func_name] = exec_->func_table.size(); + exec_->func_table.push_back(vmfunc); +} + +vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) { + auto it = exec_->func_map.find(func_name); + ICHECK(it != exec_->func_map.end()) << "Cannot find function " << func_name; + return vm::Instruction::Arg::FuncIdx(it->second); +} + +void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs, + Optional> param_names, + vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) { + auto it = exec_->func_map.find(func_name); + if (it == exec_->func_map.end()) { + this->DeclareFunction(func_name, kind); + } + auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name)); + ICHECK_EQ(vmfunc.name, func_name); + ICHECK_EQ(vmfunc.num_args, -2) << "Function " << func_name << " already defined"; + vmfunc.num_args = num_inputs; + if (param_names.defined()) { + std::vector names; + for (auto name : param_names.value()) { + names.push_back(name); + } + vmfunc.param_names = names; + } + vmfunc.register_file_size = init_register_size; + if (kind == vm::VMFuncInfo::FuncKind::kVMFunc) { + vmfunc.start_instr = exec_->instr_offset.size(); + } +} + +void ExecBuilderNode::EndFunction(const std::string& func_name) { + auto it = exec_->func_map.find(func_name); + ICHECK(it != exec_->func_map.end()); + VMFuncInfo& vmfunc = exec_->func_table.at(it->second); + ICHECK_EQ(vmfunc.end_instr, 0) << "EndFuncton can only be called once"; + + if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) { + vmfunc.end_instr = exec_->instr_offset.size(); + } +} + +void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector args, + vm::RegName dst) { + ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); + // store instruction + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Call)); + exec_->instr_data.push_back(dst); + exec_->instr_data.push_back(func.value()); + exec_->instr_data.push_back(args.size()); + for (Instruction::Arg arg : args) { + 
exec_->instr_data.push_back(arg.data()); + } +} + +void ExecBuilderNode::EmitCall(const std::string& func, std::vector args, + RegName dst) { + auto it = exec_->func_map.find(func); + if (it == exec_->func_map.end()) { + this->DeclareFunction(func, VMFuncInfo::FuncKind::kPackedFunc); + } + Index func_idx = exec_->func_map.at(func); + EmitCall(vm::Instruction::Arg::FuncIdx(func_idx), args, dst); +} + +void ExecBuilderNode::EmitRet(vm::Instruction::Arg result) { + ICHECK(result.kind() == vm::Instruction::ArgKind::kRegister); + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Ret)); + exec_->instr_data.push_back(result.value()); +} + +void ExecBuilderNode::EmitGoto(Index pc_offset) { + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Goto)); + exec_->instr_data.push_back(pc_offset); +} + +void ExecBuilderNode::EmitIf(vm::Instruction::Arg cond, vm::Index false_offset) { + ICHECK(cond.kind() == vm::Instruction::ArgKind::kRegister); + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::If)); + exec_->instr_data.push_back(cond.value()); + exec_->instr_data.push_back(false_offset); +} + +void ExecBuilderNode::CheckExecutable() { + for (auto it = exec_->func_table.cbegin(); it != exec_->func_table.cend(); ++it) { + if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; + if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + ICHECK_GE(it->register_file_size, it->num_args + 1) + << "Function " << it->name << " do not meet register file constraint."; + continue; + } + Index num_inputs = it->num_args; + std::unordered_set dst_registers; + std::unordered_set arg_registers; + size_t start_instr = it->start_instr; + size_t end_instr = it->end_instr; + + CHECK_LT(start_instr, end_instr) + << "Function " << it->name << " EndFunction has not be been called"; + + auto check_reg_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kRegister) return; + if (arg.value() >= Instruction::kBeginSpecialReg) return; + if (arg.value() < num_inputs) return; + + if (dst_registers.find(arg.value()) == dst_registers.end()) { + LOG(FATAL) << "register r(" << arg.value() << ") in VM function \"" << it->name + << "\" is used as input while it is never defined" + << " as a destination. Dump:\n" + << exec_->AsText(); + } + }; + + auto check_const_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kConstIdx) return; + CHECK_LT(arg.value(), exec_->constants.size()) + << "Constant index " << arg.value() << " exceed size of constant pool. Dump:\n" + << exec_->AsText(); + }; + + auto check_func_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kFuncIdx) return; + CHECK_LT(arg.value(), exec_->func_table.size()) + << "Func index " << arg.value() << " exceed size of fun_table. 
Dump:\n" + << exec_->AsText(); + }; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = exec_->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx)); + for (int i = 0; i < instr.num_args; ++i) { + check_reg_defined(instr.args[i]); + check_const_defined(instr.args[i]); + check_func_defined(instr.args[i]); + arg_registers.emplace(instr.args[i].value()); + } + if (instr.dst != Instruction::kVoidRegister) { + dst_registers.emplace(instr.dst); + } + break; + } + case Opcode::Ret: { + arg_registers.emplace(instr.result); + check_reg_defined(Instruction::Arg::Register(instr.result)); + break; + } + case Opcode::Goto: { + ICHECK_NE(instr.pc_offset, 0); + break; + } + case Opcode::If: { + ICHECK_GT(instr.false_offset, 1); + check_reg_defined(Instruction::Arg::Register(instr.cond)); + arg_registers.emplace(instr.cond); + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } +} + +void ExecBuilderNode::Formalize() { + // a pass to formalize user-specified register indexes in the order of use + // and decide the number of registers to allocate for each VMFunction in the Executable + for (auto it = this->exec_->func_table.begin(); it != this->exec_->func_table.end(); ++it) { + if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; + if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) continue; + + Index num_inputs = it->num_args; + RegName register_idx = num_inputs; + std::unordered_map register_map; + size_t start_instr = it->start_instr; + size_t end_instr = it->end_instr; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->exec_->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + // rewrite args + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::ArgKind::kRegister && + instr.args[i].value() >= num_inputs && + instr.args[i].value() < Instruction::kBeginSpecialReg && + register_map.find(instr.args[i].value()) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 4 + i] = + register_map[instr.args[i].value()]; + } + } + if (instr.dst >= num_inputs && instr.dst < Instruction::kBeginSpecialReg) { + auto it = register_map.find(instr.dst); + if (it != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = it->second; + } else { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_idx; + register_map[instr.dst] = register_idx++; + } + } + break; + } + case Opcode::Ret: { + if (register_map.find(instr.result) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = + register_map[instr.result]; + } + break; + } + case Opcode::Goto: { + break; + } + case Opcode::If: { + if (register_map.find(instr.cond) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_map[instr.cond]; + } + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + it->register_file_size = register_idx; + } +} + +TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") + .set_body([](TVMArgs args, TVMRetValue* ret) { + ExecBuilder builder = args[0]; + TVMRetValue rt; + rt = args[1]; + *ret = builder->ConvertConstant(rt).data(); + }); + 
+TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") + .set_body_typed([](ExecBuilder builder, String func, int64_t num_inputs, + Optional> param_names) { + builder->EmitFunction(func, num_inputs, param_names); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEndFunction") + .set_body_method(&ExecBuilderNode::EndFunction); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") + .set_body_typed([](ExecBuilder builder, String name, int32_t kind) { + builder->DeclareFunction(name, static_cast(kind)); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") + .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(Instruction::Arg::FromData(args[i]->value)); + } + auto dst_ = Instruction::Arg::Register(dst); + builder->EmitCall(name, args_, dst_.value()); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") + .set_body_typed([](ExecBuilder builder, int64_t data) { + builder->EmitRet(Instruction::Arg::FromData(data)); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto") + .set_body_method(&ExecBuilderNode::EmitGoto); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") + .set_body_typed([](ExecBuilder builder, int64_t data, vm::Index false_offset) { + builder->EmitIf(Instruction::Arg::FromData(data), false_offset); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Register(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Immediate(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::ConstIdx(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { + return builder->GetFunction(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { + ObjectPtr p_exec = builder->Get(); + return runtime::Module(p_exec); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc new file mode 100644 index 000000000000..5bf419499714 --- /dev/null +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_builtin_lower.cc + * \brief Lowers most builtin functions and packed calls. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// This pass lowers most ops to VM specific builtins. 
+// TODO(relax-team): revisit after PrimValue. +class VMBuiltinLowerMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + // A workaround to remove the CallNodes of killing tensors and storages. + void VisitBinding_(const VarBindingNode* binding) final { + const auto* call = binding->value.as(); + if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_)) { + return; + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const CallNode* call_node) final { + // post-order mutation + Call call = Downcast(VisitExprPostOrder_(call_node)); + + if (call->op == call_tir_dyn_op_) { + return CallTIRDyn(call); + } else if (call->op == reshape_op_) { + return Reshape(call); + } else if (call->op == shape_of_op_) { + return ShapeOf(call); + } else if (call->op == make_closure_op_) { + return MakeClosure(call); + } else if (call->op == invoke_closure_op_) { + return InvokeClosure(call); + } else if (call->op == alloc_tensor_op_) { + return MakeAllocTensor(call); + } else if (call->op == mem_alloc_storage_op_) { + return MakeMemAllocStorage(call); + } else if (call->op == mem_alloc_tensor_op_) { + return MakeMemAllocTensor(call); + } else { + return call; + } + } + + Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const { + // Question: what if the dtype of tensor_type is unknown? + // Symbolic/static shape case + if (auto* shape_expr = shape.as()) { + int64_t elem_bytes = runtime::GetVectorBytes(dtype); + PrimExpr ret = IntImm(DataType::Int(64), elem_bytes); + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + return ShapeExpr({ret}); + } else { + return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(), + {GetStructInfo(shape)}); + } + } + + Expr MakeAllocTensor(const Call& call) { + ShapeExpr output_shape = Downcast(call->args[0]); + DataTypeImm output_dtype = Downcast(call->args[1]); + DataType dtype = output_dtype->value; + Expr storage_size = ComputeStorageSize(output_shape, dtype); + PrimValue runtime_device_index = Downcast(call->args[2]); + Var storage = builder_->Emit( + Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, output_dtype}, Attrs()), + "storage"); + Expr shape = call->args[0]; + PrimValue offset = PrimValue::Int64(0); + return Call(vm_alloc_tensor_op_, {storage, offset, shape, DataTypeImm(dtype)}, Attrs()); + } + + Expr MakeMemAllocStorage(const Call& call) { + PrimValue runtime_device_index = Downcast(call->args[1]); + DataTypeImm output_dtype = Downcast(call->args[3]); + return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype}, Attrs()); + } + + Expr MakeMemAllocTensor(const Call& call) { + PrimValue offset = Downcast(call->args[1]); + DataTypeImm dtype = Downcast(call->args[3]); + return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs()); + } + + Expr CallTIRDyn(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + Array args; + + auto tir_args = Downcast(call_node->args[1]); + args.push_back(call_node->args[0]); + for (Expr arg : tir_args->fields) { + args.push_back(arg); + } + return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_}); + } + + Expr Reshape(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->struct_info_.defined()); + auto arg = call_node->args[1]; + CHECK(arg->IsInstance() || arg->IsInstance()) + << "VMBuiltinLower expects the 
shape arg of reshape op to be a ShapeExpr or VarNode bound " + "to a ShapeExpr"; + + if (arg->IsInstance()) { + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } else { + // Handling the case when arg is VarNode + Optional _bound_val = LookupBinding(Downcast(arg)); + ICHECK(_bound_val.defined()); + Expr bound_val = _bound_val.value(); + CHECK(bound_val->IsInstance()) + << "VMBuiltinLower expects bound value to be a ShapeExpr"; + return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(), + {GetStructInfo(call_node)}); + } + } + + Expr ShapeOf(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + + Expr MakeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + Array args; + auto func = call_node->args[0]; + auto closure_args = Downcast(call_node->args[1]); + + args.push_back(func); + for (Expr arg : closure_args->fields) { + args.push_back(arg); + } + + return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_}); + } + + Expr InvokeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + Array args; + + args.push_back(call_node->args[0]); + + // args for the invoke_closure + auto invoke_closure_args = Downcast(call_node->args[1]); + for (Expr arg : invoke_closure_args->fields) { + args.push_back(arg); + } + return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(), + {object_sinfo_}); + } + + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + // object to pattern match. + const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); + const Op& reshape_op_ = Op::Get("relax.reshape"); + const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); + const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor"); + const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage"); + const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor"); + const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage"); + const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor"); + // functions to lower to + const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + // Function to compute allocated shape. 
+ const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"}; + const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; + const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; + const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; + const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; + const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; +}; + +Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } + +namespace transform { + +Pass VMBuiltinLower() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(VMBuiltinLower(f)); }; + return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc new file mode 100644 index 000000000000..f4b272979bb6 --- /dev/null +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -0,0 +1,730 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_shape_lower.cc + * \brief Lower the function boundary type checks and symbolic shape computations. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief A slot used in PrimExpr lowering. */ +struct PrimExprSlot { + /*! \brief The existing */ + PrimExpr expr; + /*! \brief The slot index */ + int index; + // The following three members are auxiliary data + // to help shape rewriting. + /*! + * \brief List of slots whose PrimExpr uses this PrimExpr. + * \note Users won't be empty only if PrimExpr is a Var and it does not include itself. + */ + std::vector user_slots; + /*! + * \brief Number of outstanding vars that are not defined in this PrimExpr. + * \note This is a helper counter used in analysis to perform computations. + */ + int outstanding_defs = 0; + /*! \brief Whether we have computed the value. */ + bool value_computed = false; +}; + +/*! + * \brief Helper dats structure to collect pairs of match shapes + * in a recursive matching process. + */ +struct MatchShapeTodoItem { + Expr input; + Array pattern; + String err_ctx; +}; + +/*! \brief Slot map used for shape lowering. 
*/ +using PrimExprSlotMap = + std::unordered_map; + +// Collector to collect PrimExprSlotMap +class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { + public: + // collect the PrimExpr slot for a given function + static void Collect(Function func, std::vector>* slot_vec, + PrimExprSlotMap* slot_map) { + PrimExprSlotCollector collector; + collector.slot_vec_ = slot_vec; + collector.slot_map_ = slot_map; + // collect shape declaration in func params + for (auto param : func->params) { + collector.VisitStructInfo(GetStructInfo(param)); + collector.VisitExpr(param); + } + collector.VisitExpr(func->body); + } + + private: + void VisitPrimExpr(const PrimExpr& expr) final { + if (expr->IsInstance()) return; + if (slot_map_->count(expr) == 0) { + auto slot = std::make_unique(); + slot->expr = expr; + slot->index = static_cast(slot_vec_->size()); + slot_map_->emplace(expr, slot.get()); + slot_vec_->emplace_back(std::move(slot)); + } + } + + void VisitBinding_(const MatchCastNode* op) final { + // Visit the match cast struct info so we can define + // the symbolic variables here. + this->VisitStructInfo(op->struct_info); + } + + void VisitExpr_(const FunctionNode* op) final { + // Do not recurse into function node as it is self-contained + } + + void VisitStructInfo_(const FuncStructInfoNode* op) final { + // Do not recurse into function struct info as it is self-contained + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + + void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } + + std::vector>* slot_vec_; + PrimExprSlotMap* slot_map_; +}; + +/*! + * \brief Main logic to transform the shape lowered functions + * + * Consider the following input: + * + * \code + * + * def f(x: R.Tuple(R.Tensor([m, n+1]), R.Tensor([n, 2])) -> R.Tensor: + * return x + * + * \endcode + * + * Overall flow of the algorithm: + * - Preprocess: PrimExprSlot collection, we scan the function and allocate PrimExprSlot + * for each PrimExpr. In the above example, the result mapping from the slot index + * to expr would be {0:m, 1: n+1: 2: n}. Note that "n+1" also get a slot. + * PrimExprSlot also comes with auxiliary fields that track whether its value + * can be readily computed. + * + * Steps at each matching point: + * - Step 0: We call CheckMatchCast, + * which will recursively unpack the StructInfo, and generate static information checks. + * Note that this step only generates functions for checking types and ndim info, but not + * the symbolic shape variables. The symbolic shape-matching results will be returned as + * vector. This is because symbolic shape matching may not be completed + * in a single round. Importantly, CheckMatchCast also deals with tuple unpacking. + * + * - Step 1: We then call RunMatch to generate the statements for matching symbolic shapes. + * In the above example, the first round will store the value of m, n to their corresponding + * slot. RunMatch may return outstanding items. In the above example x.shape[1] == n+1 cannot + * be checked in the first round. RunMatch will populate new vars(this case n, m), these vars + * are added to a ready queue (ready_vars_) + * + * - Step 2: We EmitOutstandingPrimExprCompute to check if ready_vars will trigger new values + * to be computed. We eagerly compute all the outstanding values. The trigger is done through + * a ref counter which decreases when each outstanding def is satisfied. 
+ * This step can also generate additional TIR functions to carry out shape computations. + * + * - Step 3: RunMatch again for given outstanding match todos. This time all invariants + * should be checked. + * + * The above step would populate each slot(which is backed by an element in shape_heap). + * Each time we find a symbolic shape tuple, we call MakeShape for given slot indices + * in the shape_heap. + * + * + * Key functions in the flow: + * - PrimExprSlotCollector: preprocessing and collecting the slots + * - CheckMatchCast: recursively structinfo unpacking, generate checks and match items. + * - RunMatch: generate symbolic shape matches + * - EmitOutstandingPrimExprCompute: tracks the variables to be computed and emit shape computation + * - VisitExpr_(ShapeExprNode*): makes symbolic shape tuple. + * + * The checks and symbolic shape all maps to runtime builtin functions. Please checkout + * runtime/relax_vm/builtin.cc for their definitions. + * + * Shape computation are lowered to host-side TIR functions that load var from slot + * and store computed results into the slot. For a given slot map: {0:m, 1: n+1: 2: n} + * It will create the shape_func below that loads data from H[2](n's slot) run compute + * and store back to H[1](n+1's slot). + * + * \code + * + * @T.prim_func + * def shape_func(H: T.Buffer([3], "int64")): + * H[1] = H[2] + 1 + * + * \endcode + * + * The current implementation will batch all shape computations at each match point. + * For example, all the expressions that depend on n, m will be computed in a single + * shape_func at the function boundary. If there are follow-up match_cast points, + * that defines new variable, then we might we will generate new shape functions + * to compute expressions that depend on these variables. + */ +class VMShapeLowerMutator + : public ExprMutator, + public StructInfoFunctor*)> { + public: + static IRModule Lower(IRModule mod, bool emit_err_ctx) { + VMShapeLowerMutator mutator(mod, emit_err_ctx); + + for (auto& kv : mod->functions) { + if (auto* func = kv.second.as()) { + Function updated_func = mutator.Rewrite(kv.first, GetRef(func)); + mutator.builder_->UpdateFunction(kv.first, updated_func); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit VMShapeLowerMutator(IRModule mod, bool emit_err_ctx) + : ExprMutator(mod), emit_err_ctx_(emit_err_ctx) {} + + using ExprMutator::VisitExpr_; + + // Unit rewrite function per function. + Function Rewrite(GlobalVar gvar, Function func) { + // prepare mapping and heap var + PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_); + heap_size_ = IntImm(ShapeDType(), static_cast(slot_vec_.size())); + VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_); + shape_heap_ = shape_heap_binding->var; + + // prepare slot information + this->PopulateSlotInfo(); + + Array blocks; + + builder_->BeginScope(func->params); + + { + // Check the parameter section. + builder_->BeginBindingBlock(); + this->builder_->EmitNormalized(shape_heap_binding); + std::vector match_todos; + for (size_t i = 0; i < func->params.size(); ++i) { + StructInfo sinfo = GetStructInfo(func->params[i]); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i + << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") "; + this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), &match_todos); + } + // insert heap generation logic. 
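+      // The first RunMatch stores newly seen shape variables into the heap,
+      // EmitOutstandingPrimExprCompute emits a shape function for expressions that
+      // depend on them, and the second RunMatch asserts the remaining todo items.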
+ match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + BindingBlock pre_block = builder_->EndBlock(); + blocks.push_back(pre_block); + } + + // new body. + auto body_seq = Downcast(this->VisitWithNewScope(func->body, func->params)); + blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end()); + + { + // Insert the return value check + builder_->BeginBindingBlock(); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint + << ", loc=return, annotation=" << func->ret_struct_info << ") "; + std::vector match_todos; + // NOTE: the return value's shape computation must already be defined. + this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, err_ctx.str(), + &match_todos); + // NOTE: the return value's shape computation must already be defined. + this->RunMatch(match_todos, true); + BindingBlock post_block = builder_->EndBlock(); + blocks.push_back(post_block); + } + + auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body)); + // create a new function + return Function(func->params, new_body, func->ret_struct_info, func->attrs); + } + + //------------------------------------------------------- + // PrimExpr slot handling + //------------------------------------------------------- + static DataType ShapeDType() { return DataType::Int(64); } + + /*! \brief populate additional information in the slot. */ + void PopulateSlotInfo() { + for (auto& kv : slot_map_) { + auto* slot = kv.second; + if (!slot->expr.as()) { + Array dep_vars = tir::UndefinedVars(slot->expr); + for (auto var : dep_vars) { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()) + << "Var " << var << "is not defined in the function but is referenced by " + << slot->expr; + auto* var_slot = it->second; + // populate the use slot. + var_slot->user_slots.push_back(slot); + } + // set outstanding defs. + slot->outstanding_defs += static_cast(dep_vars.size()); + } + } + } + //------------------------------------------------------- + // Helper functions + //------------------------------------------------------- + StringImm GetErrContext(String err_ctx) const { + return emit_err_ctx_ ? StringImm(err_ctx) : StringImm(""); + } + + VarBinding AllocShapeHeapBinding(IntImm heap_size) { + if (heap_size->value > 0) { + TensorStructInfo heap_sinfo(ShapeDType(), 1); + Var var("shape_heap", heap_sinfo); + // set up the builtin func. + Call call(call_builtin_with_ctx_op_, + {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_sinfo}); + UpdateStructInfo(call, heap_sinfo); + return VarBinding(var, call); + } else { + Var var("shape_heap", ObjectStructInfo()); + Call call(null_value_op_, {}); + UpdateStructInfo(call, ObjectStructInfo()); + return VarBinding(var, call); + } + } + + //------------------------------------------------------- + // Expr mutation overloading. + //------------------------------------------------------- + Expr VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "VMShapeLower do not work for local functions, make sure " + << " to run it after LambdaLift"; + return GetRef(op); + } + + Expr VisitExpr_(const ShapeExprNode* op) final { + using runtime::relax_vm::MakeShapeCode; + // Constant shape can be preserved. 
+ bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) { + return e->IsInstance(); + }); + if (is_const_shape) { + return GetRef(op); + } + + Array args = {shape_heap_, PrimValue::Int64(static_cast(op->values.size()))}; + for (PrimExpr expr : op->values) { + if (auto* int_expr = expr.as()) { + args.push_back(PrimValue::Int64(static_cast(MakeShapeCode::kUseImm))); + args.push_back(PrimValue::Int64(int_expr->value)); + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed"; + args.push_back(PrimValue::Int64(static_cast(MakeShapeCode::kLoadShape))); + args.push_back(PrimValue::Int64(slot->index)); + } + } + + // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) + Call call(builtin_make_shape_, args, Attrs(), + {ShapeStructInfo(static_cast(op->values.size()))}); + return call; + } + + void VisitBinding_(const MatchCastNode* binding) final { + Expr value = ExprMutator::VisitExpr(binding->value); + std::vector match_todos; + std::ostringstream err_ctx; + err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") "; + // always_check=false + this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), &match_todos); + + match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + // These checks are emitted as extra, in codegen + // match-cast is simply ignored and treated as a normal binding. + builder_->EmitNormalized(GetRef(binding)); + } + + // Do not override shape in struct info fields + // We only override the shape that are already part of the normal function values + // If future passes lift those values out into the values, + // then codegen may not be able to handle symbolic values. + // Place this pass as last pass before codegen. + StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; } + + //------------------------------------------------------- + // Shape computations. + //------------------------------------------------------- + /*! + * \brief Execute the match todo items. + * + * This function can populate vars in the match items when seeing it for the first time. + * These new vars will be added to this->ready_vars_. + * + * If an item contains PrimExpr that are yet to be computed (but may be computable through + * vars defined in this round), it will be returned to the caller. + * + * The caller should call EmitOutstandingPrimExprCompute, then call RunMatch again. + * + * \param match_todos The list of match items to be executed. + * \param require_value_computed Whether we require all expr to be computed. + * \return List of outstanding items that contains value that are yet to be computed. 
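+ *
+ * As an illustration (not part of this patch), the per-dimension protocol encoded by the
+ * (code, rvalue) argument pairs can be modelled by the standalone snippet below. The enum
+ * numbering here is a placeholder; the real codes and the builtin are defined in
+ * runtime/relax_vm/builtin.cc.
+ *
+ * \code
+ * #include <cstddef>
+ * #include <cstdint>
+ * #include <stdexcept>
+ * #include <vector>
+ *
+ * enum class MatchShapeCode : int64_t {
+ *   kNoOp = 0, kStoreToHeap = 1, kAssertEqualToImm = 2, kAssertEqualToLoad = 3
+ * };
+ *
+ * void MatchShape(const std::vector<int64_t>& shape, int64_t* heap,
+ *                 const std::vector<int64_t>& codes_and_rvalues) {
+ *   for (std::size_t i = 0; i < shape.size(); ++i) {
+ *     auto code = static_cast<MatchShapeCode>(codes_and_rvalues[2 * i]);
+ *     int64_t rvalue = codes_and_rvalues[2 * i + 1];
+ *     if (code == MatchShapeCode::kStoreToHeap) heap[rvalue] = shape[i];  // define a var
+ *     if (code == MatchShapeCode::kAssertEqualToImm && shape[i] != rvalue)
+ *       throw std::runtime_error("shape mismatch");  // check against a constant
+ *     if (code == MatchShapeCode::kAssertEqualToLoad && shape[i] != heap[rvalue])
+ *       throw std::runtime_error("shape mismatch");  // check against a computed slot
+ *   }
+ * }
+ * \endcode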
+ */ + std::vector RunMatch(const std::vector& match_todos, + bool require_value_computed) { + std::vector outstanding_todos; + + using runtime::relax_vm::MatchShapeCode; + for (const MatchShapeTodoItem& item : match_todos) { + int64_t shape_len = static_cast(item.pattern.size()); + bool all_nop = true; + int num_outstanding_exprs = 0; + + Array args = {item.input, shape_heap_, PrimValue::Int64(shape_len)}; + + for (PrimExpr expr : item.pattern) { + MatchShapeCode code = MatchShapeCode::kNoOp; + int64_t rvalue = 0; + if (auto* int_expr = expr.as()) { + code = MatchShapeCode::kAssertEqualToImm; + rvalue = int_expr->value; + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + if (slot->value_computed) { + code = MatchShapeCode::kAssertEqualToLoad; + rvalue = slot->index; + } else { + // the value is not yet computed + ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; + if (expr.as()) { + // if it is a var, we will populate it in this round. + // otherwise, we skip and mark it as outstanding + code = MatchShapeCode::kStoreToHeap; + rvalue = slot->index; + slot->value_computed = true; + ready_vars_.push_back(slot); + } else { + code = MatchShapeCode::kNoOp; + rvalue = 0; + ++num_outstanding_exprs; + } + } + } + all_nop = all_nop && code == MatchShapeCode::kNoOp; + args.push_back(PrimValue::Int64(static_cast(code))); + args.push_back(PrimValue::Int64(rvalue)); + } + if (num_outstanding_exprs != 0) { + outstanding_todos.push_back(item); + } + args.push_back(GetErrContext(item.err_ctx)); + if (!all_nop) { + Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + } + return std::move(outstanding_todos); + } + + /*! + * \brief Compute a list of prim expr that now be computed + * for given ready vars. + */ + std::vector GetReadyPrimExprSlots() { + std::vector to_compute; + for (PrimExprSlot* slot : ready_vars_) { + for (PrimExprSlot* user : slot->user_slots) { + ICHECK_GT(user->outstanding_defs, 0); + user->outstanding_defs -= 1; + if (user->outstanding_defs == 0) { + to_compute.push_back(user); + } + } + } + ready_vars_.clear(); + return to_compute; + } + + /*! + * \brief Check the dependent expressions of ready_vars_, + * + * If there are outstanding PrimExpr that can now be computed + * we generate a PrimFunc that compute the extra shape values + * + * We will then clear the ready_vars. + * + * \return Number of PrimExpr computed. + */ + size_t EmitOutstandingPrimExprCompute() { + std::vector to_compute = GetReadyPrimExprSlots(); + if (to_compute.size() == 0) return 0; + ICHECK_GT(heap_size_->value, 0); + // construct a PrimFunc that compute the shape. 
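+ /*
+  * [Illustrative sketch, not part of this patch] RunMatch and GetReadyPrimExprSlots above use a
+  * simple ready-propagation scheme: each composite expression slot counts how many symbolic vars
+  * it still waits for, and each var slot knows which expression slots use it. A standalone
+  * version of that step (names invented for illustration):
+  *
+  * \code
+  * #include <vector>
+  *
+  * struct Slot {
+  *   int outstanding_defs = 0;       // vars this expression still waits for
+  *   std::vector<Slot*> user_slots;  // expression slots that use this var
+  * };
+  *
+  * // Called when a var slot becomes known; returns the expression slots whose
+  * // inputs are now all defined and can be emitted into the next shape_func.
+  * std::vector<Slot*> OnVarReady(Slot* var_slot) {
+  *   std::vector<Slot*> ready;
+  *   for (Slot* user : var_slot->user_slots) {
+  *     if (--user->outstanding_defs == 0) ready.push_back(user);
+  *   }
+  *   return ready;
+  * }
+  * \endcode
+  */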
+ tir::Var heap("heap", DataType::Handle()); + Array buffer_shape{heap_size_}; + tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); + Map buffer_map; + buffer_map.Set(heap, buffer); + + auto var_map = [&](const tir::Var& var) -> Optional { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()); + return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + }; + + Array seq; + for (PrimExprSlot* slot : to_compute) { + ICHECK(!slot->value_computed); + slot->value_computed = true; + PrimExpr value = tir::Substitute(slot->expr, var_map); + seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + } + + tir::Stmt body = tir::SeqStmt::Flatten(seq); + Array params{heap}; + Type ret_type = VoidType(); + + // TODO(relax-team): Consider attach the target attribute to + // the shape_func to indicate that this is a host function + // This could require us to attach target to the relax function here. + tir::PrimFunc shape_func(params, body, ret_type, buffer_map); + if (shape_func->attrs.GetAttr(tvm::attr::kTarget) == nullptr) { + // kTarget and kIsHostFunc are mutually exclusive + shape_func = + WithAttr(std::move(shape_func), tvm::tir::attr::kIsHostFunc, Integer(1)); + } + GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + return to_compute.size(); + } + //------------------------------------------------------- + // StructInfo value match logic + // + // CheckMatchCast is the only function needed by + // other code sections + //------------------------------------------------------- + /*! + * \brief Insert runtime check of the match cast condition(value, struct_info). + * + * \param struct_info The struct info to be matched. + * \param value The input value. + * \param always_check Whether we insert runtime check even if we can prove + * that value's struct info already satisfies the condition. + * This option is necessary for argument checking per our calling convention. + * + * \param err_ctx Extra error context to bring more informative error reporting. + * \param match_todos List of match shape todo items collected when recursively + * visit the match cast. + */ + void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) { + return this->VisitStructInfo(struct_info, value, always_check, err_ctx, match_todos); + } + + void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // short-cut, if the struct info already satisfies the + // constraint during match cast, we can skip matching + if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return; + return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, err_ctx, + match_todos); + } + + void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + } + + void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // TODO(relax-team) add PrimValue checks later. 
+ LOG(FATAL) << "MatchCast of PrimValue is not yet supported"; + } + + void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { + // check_shape_info(value, ndim, err_ctx) + Call call(builtin_check_shape_info_, + {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), + {void_sinfo_}); + builder_->Emit(call, "_"); + } + if (op->values.defined()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = op->values.value(); + item.err_ctx = err_ctx; + match_todos->push_back(item); + } + } + + void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) { + // check_tensor_info(value, ndim, dtype, err_ctx) + Call call(builtin_check_tensor_info_, + {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + if (auto* shape_expr = op->shape.as()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = shape_expr->values; + item.err_ctx = err_ctx; + match_todos->push_back(item); + } else if (op->shape.as()) { + // NOTE: This part of the logic is left empty for future support as it is less common. + // Future implementors: we can emit a binding here and assert here. + LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears multiple times"; + } else { + ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var"; + } + } + + // Internal helper function to make tuple get item. + // This function will try to simplify constant tuples + // the return value **always** have struct info. + Expr MakeTupleGetItem(Expr value, int64_t index) { + if (auto* tuple_expr = value.as()) { + return tuple_expr->fields[index]; + } else if (auto* tuple_sinfo = GetStructInfoAs(value)) { + // value is tuple type, it is OK to run tuple get item. + auto ret = TupleGetItem(value, index); + UpdateStructInfo(ret, tuple_sinfo->fields[index]); + return ret; + } else { + // call runtime tuple get item, and return a object. 
+ Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_}); + UpdateStructInfo(call, ObjectStructInfo()); + return call; + } + } + + void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + auto* value_tinfo = GetStructInfoAs(value); + if (value_tinfo) { + CHECK_EQ(value_tinfo->fields.size(), op->fields.size()) + << "TypeError: " << err_ctx << " during match-cast we find tuple size mismatch"; + } + if (always_check || !value_tinfo) { + // check_tuple_info(value, tuple_size) + Call call(builtin_check_tuple_info_, + {value, PrimValue::Int64(static_cast(op->fields.size())), + GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + // recursively visit each sub-field and run matching + for (size_t i = 0; i < op->fields.size(); ++i) { + this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, err_ctx, + match_todos); + } + } + + void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // we only check function is callable. + if (!always_check && MatchStructInfo(value)) return; + // check_func_info(value, err_ctx) + Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + //------------------------------------------------------- + // Private member fields. + //------------------------------------------------------- + /*! \brief whether to emit error context, can be turned off for testing purposes. */ + bool emit_err_ctx_{true}; + /*! \brief heap ptr to store the PrimExpr slots. */ + Var shape_heap_; + /*! \brief heap size. */ + IntImm heap_size_; + /*! \brief index => slot. */ + std::vector> slot_vec_; + /*! \brief Expr => slot. */ + PrimExprSlotMap slot_map_; + /*! + * \brief List of vars that are being defined but + * have not go through outstanding shape compute check. 
+ */ + std::vector ready_vars_; + // call builtin cop + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); + // common struct info + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + // check function + const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; + const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; + const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"}; + const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"}; + const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"}; + const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"}; + const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"}; + const ExternFunc builtin_tuple_getitem_{"vm.builtin.tuple_getitem"}; +}; + +namespace transform { + +Pass VMShapeLower(bool emit_err_ctx) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator::Lower(mod, emit_err_ctx); }; + return CreateModulePass(pass_func, 0, "VMShapeLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { + return VMShapeLower(emit_err_ctx); +}); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc new file mode 100644 index 000000000000..64866464fad5 --- /dev/null +++ b/src/relax/ir/binding_rewrite.cc @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/binding_rewrite.cc + * \brief Implementation of binding rewriters. 
+ */ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); + +DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { + auto n = make_object(); + n->dfb_ = dfb; + n->root_fn_ = root_fn; + n->original_fn_ptr_ = root_fn.get(); + auto p = FunctionUseDef(root_fn); + n->to_users_ = std::move(p.first); + n->fn_outputs_ = std::move(p.second); + n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); + + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") + .set_body_typed([](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); + +void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { + class ReplaceAllUsePass : public ExprMutator { + Var old_var, new_var; + const DataflowBlockNode* const to_catch; + + public: + DataflowBlock caught; + + ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch) + : old_var(old_var), new_var(new_var), to_catch(to_catch) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const VarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + BindingBlock res = ExprMutator::VisitBindingBlock_(op); + if (op == to_catch) caught = Downcast(res); + return res; + } + }; + + ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + + // replace uses inside the DataflowBlock. + ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); + if (root_fn_) { + root_fn_ = Downcast(replacer.VisitExpr(root_fn_.value())); + dfb_ = replacer.caught; + } else { + dfb_ = Downcast(replacer.VisitBindingBlock(dfb_)); + } + + // update udchain + // old_var -> old_var users | changed to {} + // new_var -> {?} | changed to old_var users + for (Var user : to_users_[old_var]) { + auto new_var_uses = to_users_[new_var]; + if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) { + new_var_uses.push_back(user); + } + } + + to_users_.Set(old_var, {}); + + auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var); + if (it_old_output != fn_outputs_.end()) { + fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var); + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") + .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); + +class UpdateDFB : public ExprMutator { + private: + DataflowBlock old_dfb, new_dfb; + + public: + UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb) + : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + return old_dfb.get() == op ? 
new_dfb : old_dfb; + } +}; + +// TODO(masahi): Consider moving this to analysis +std::set GetUsedVars(Expr val) { + class UsedVars : public ExprVisitor { + public: + std::set used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); } + } uvar{}; + uvar.VisitExpr(val); + return std::move(uvar.used_vars); +} + +void DataflowBlockRewriteNode::Add(Binding binding) { + auto [var, val] = [binding] { + if (auto vb = binding.as()) { + return std::make_pair(vb->var, vb->value); + } else if (auto mc = binding.as()) { + return std::make_pair(mc->var, mc->value); + } + LOG(FATAL) << "Unsupported binding type"; + return std::make_pair(Var{}, Expr{}); + }(); + + ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + + // Add this VarBinding statement after the definition of uses. + auto used_vars = GetUsedVars(val); + + size_t line_last_req_def = 0; + for (size_t i = 0; i < dfb_->bindings.size(); ++i) { + auto line = dfb_->bindings[i]; + if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i; + } + + auto old_dfb = dfb_; + + dfb_.CopyOnWrite()->bindings.insert(dfb_->bindings.begin() + 1 + line_last_req_def, binding); + + if (root_fn_) { + auto updater = UpdateDFB(old_dfb, dfb_); + root_fn_ = Downcast(updater.VisitExpr(root_fn_.value())); + } + + for (const VarNode* v : used_vars) { + auto var = GetRef(v); + if (auto users = to_users_.Get(var)) { + users.value().push_back(var); + } + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") + .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") + .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + if (name.get()) { + rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); + +std::set GetUnusedVars(Map> users_map, Array fn_outputs) { + std::vector unused; + + // iterative dataflow algorithm. + size_t prev_size; + do { + prev_size = unused.size(); + + std::vector used; + used.reserve(users_map.size()); + for (const auto& [def, users] : users_map) { + // var -> [users...] + // var is unused iff + // user -> empty + // var is not output var + if (users.empty() && // def is not used by fn outputs. + std::find(fn_outputs.begin(), fn_outputs.end(), def) == fn_outputs.end()) { + unused.push_back(def); + } else { + used.push_back(def); + } + } + + for (size_t i = prev_size; i < unused.size(); ++i) { + users_map.erase(unused[i]); + // remove def site. + for (const auto& used_var : used) { + ICHECK(users_map.count(used_var)); + Array var_users = users_map[used_var]; + // remove the unused var from the use site. + if (auto it = std::find(var_users.begin(), var_users.end(), unused[i]); + it != var_users.end()) { + var_users.erase(it); + users_map.Set(used_var, std::move(var_users)); + } + } + } + } while (prev_size != unused.size()); // changed? => continue. 
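+ /*
+  * [Illustrative sketch, not part of this patch] GetUnusedVars above iterates to a fixpoint:
+  * a definition is dead once nothing alive uses it and it is not a function output, and
+  * removing one dead definition can make its inputs dead in the next round. The same fixpoint
+  * over a toy def -> users map (string names stand in for Vars):
+  *
+  * \code
+  * #include <map>
+  * #include <set>
+  * #include <string>
+  * #include <vector>
+  *
+  * std::set<std::string> DeadVars(const std::map<std::string, std::vector<std::string>>& users,
+  *                                const std::set<std::string>& outputs) {
+  *   std::set<std::string> dead;
+  *   bool changed = true;
+  *   while (changed) {
+  *     changed = false;
+  *     for (const auto& [def, uses] : users) {
+  *       if (dead.count(def) || outputs.count(def)) continue;
+  *       bool all_uses_dead = true;
+  *       for (const auto& u : uses) all_uses_dead = all_uses_dead && dead.count(u) > 0;
+  *       if (all_uses_dead) { dead.insert(def); changed = true; }  // may unlock more next round
+  *     }
+  *   }
+  *   return dead;
+  * }
+  * \endcode
+  */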
+ + return std::set(unused.begin(), unused.end()); +} + +class RemoveUnusedVars : public ExprMutator { + public: + std::set unused_vars; + Optional caught_rewrite = NullOpt; + + RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} + + RemoveUnusedVars(Map> users, Array fn_outputs) + : RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + auto prev_dfb = GetRef(block); + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + if (!unused_vars.count(binding->var)) { + VisitBinding(binding); + } + } + auto new_dfb = builder_->EndBlock(); + if (caught_rewrite == prev_dfb) caught_rewrite = Downcast(new_dfb); + return std::move(new_dfb); + } +}; + +void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { + // first need to check if this var is used. + if (to_users_.count(unused) == 0) { // no def. + if (allow_undef) return; + LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + } + + ICHECK(to_users_[unused].empty()) + << unused << " is used by " << to_users_[unused].size() << " vars"; + + auto old_dfb = dfb_; + + RemoveUnusedVars remover({unused}); + dfb_ = Downcast(remover.VisitBindingBlock(old_dfb)); + + if (root_fn_) { + auto updater = UpdateDFB(old_dfb, dfb_); + root_fn_ = Downcast(updater.VisitExpr(root_fn_.value())); + } + + to_users_.erase(unused); // update use-def chain. +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") + .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); + +void DataflowBlockRewriteNode::RemoveAllUnused() { + RemoveUnusedVars remover(to_users_, fn_outputs_); + remover.caught_rewrite = dfb_; + + if (root_fn_) { + // this could also clean unused variables in other DataflowBlock. + root_fn_ = Downcast(remover.VisitExpr(root_fn_.value())); + // DataflowBlock could be None. + dfb_ = remover.caught_rewrite.value(); + } else { + dfb_ = Downcast(remover.VisitBindingBlock(dfb_)); + } + + // clean up use-def chain. + for (const auto& unused : remover.unused_vars) to_users_.erase(unused); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") + .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); + +Function RemoveAllUnused(Function fn) { + auto [users, outputs] = FunctionUseDef(fn); + RemoveUnusedVars remover(users, outputs); + return Downcast(remover.VisitExpr_(fn.get())); +} + +TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); + +IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { + BlockBuilder builder = BlockBuilder::Create(irmod); + + for (auto& [gvar, fn] : irmod->functions) { + if (root_fn_ && original_fn_ptr_ == fn.get()) { + builder->UpdateFunction(gvar, root_fn_.value()); + break; + } + } + + return builder->GetContextIRModule(); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") + .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { + return rwt->MutateIRModule(irmod); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc new file mode 100644 index 000000000000..ac92114ef9cb --- /dev/null +++ b/src/relax/ir/block_builder.cc @@ -0,0 +1,948 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/block_builder.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Block builder have three categories of logics that are interdependent with each other. +// +// The logics are somewhat interdependent with each other. +// To help us implement a block builder in two parts: +// +// - BlockBuilderImpl: implements ctx and scope management, with no normalization. +// - BlockBuilderImplWithNormalize: subclasses BlockBuilderImpl and implements normalization. +// +// The final blockbuilder create will be backed by BlockBuilderWithNormalize + +namespace tvm { +namespace relax { + +//--------------------------------------- +// ctx and scope management. +//--------------------------------------- +class BlockBuilderImpl : public BlockBuilderNode { + public: + explicit BlockBuilderImpl(IRModule context_mod) : context_mod_(std::move(context_mod)) {} + + ~BlockBuilderImpl() { + if (!block_stack_.empty()) { + LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; + } + } + + //------------------------------- + // Global Context management + //------------------------------- + NameTable* name_table() final { return name_table_.get(); } + + IRModule GetContextIRModule() const final { return context_mod_; } + + GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final { + LazyInitCtxFuncDedupMap(); + auto it = ctx_func_dedup_map_->find(func); + if (it == ctx_func_dedup_map_->end()) { + context_mod_.CopyOnWrite(); + + String func_name = name_table_->GetUniqueName(func_name_hint); + while (context_mod_->ContainGlobalVar(func_name)) { + func_name = name_table_->GetUniqueName(func_name_hint); + } + GlobalVar gvar = GlobalVar(func_name); + + StructInfo finfo; + if (func->struct_info_.defined()) { + finfo = GetStructInfo(func); + } else if (auto* prim_func = func.as()) { + // NOTE: use a slightly different struct info than checked type + // in PrimFunc so handle can turn into Tensor. + // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. 
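+ /*
+  * [Illustrative sketch, not part of this patch] AddFunction above keeps a lazily initialized
+  * function -> GlobalVar map so that adding the same function twice yields the same symbol
+  * instead of a renamed duplicate. A toy version of that dedup-or-insert pattern (string keys
+  * stand in for structurally equal functions; the suffix scheme is invented):
+  *
+  * \code
+  * #include <string>
+  * #include <unordered_map>
+  *
+  * struct ToyModule {
+  *   std::unordered_map<std::string, std::string> dedup;  // function body -> global symbol
+  *   int counter = 0;
+  *
+  *   std::string AddFunction(const std::string& body, const std::string& name_hint) {
+  *     auto it = dedup.find(body);
+  *     if (it != dedup.end()) return it->second;  // reuse the existing GlobalVar
+  *     std::string unique = name_hint + "_" + std::to_string(counter++);
+  *     dedup.emplace(body, unique);
+  *     return unique;
+  *   }
+  * };
+  * \endcode
+  */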
+ finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + } else { + finfo = StructInfoFromType(func->checked_type()); + } + UpdateStructInfo(gvar, finfo); + + context_mod_->Add(gvar, func); + + ctx_func_dedup_map_->emplace(func, gvar); + return gvar; + } else { + return it->second; + } + } + + void UpdateFunction(const GlobalVar& gv, BaseFunc function) final { + context_mod_.CopyOnWrite(); + + // invalidate old dedup map + if (ctx_func_dedup_map_ != nullptr) { + auto it = context_mod_->functions.find(gv); + if (it != context_mod_->functions.end()) { + BaseFunc old_func = (*it).second; + auto ptr = ctx_func_dedup_map_->find(old_func); + ICHECK(ptr != ctx_func_dedup_map_->end()); + ctx_func_dedup_map_->erase(ptr); + } + } + + context_mod_->Update(gv, function); + + // add new dedup map item. + if (ctx_func_dedup_map_ != nullptr) { + ctx_func_dedup_map_->emplace(function, gv); + } + } + + void ReportFatal(const Diagnostic& diagnostic) final { + // TODO(relax-team): Print more context information by looking + // into the diagnostic->loc and surrounding IRModule. + // We do not materialzie DiagnosticContext to avoid double referencing to + // the change IRModule in COW. Additionally, we need to be able to + // continue use the builder after an error is thrown to avoid state building up. + // in an interactive environment. + LOG(FATAL) << diagnostic->message; + } + + //------------------------------- + // Scope management + //------------------------------- + Optional LookupBinding(const Var& var) final { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) return NullOpt; + return it->second; + } + + void BeginDataflowBlock() final { block_stack_.emplace_back(BlockFrame{{}, true}); } + + void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } + + void BeginScope(Optional> params) final { + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + // + // TODO(relax-team): Add support for relax Var in struct info annotations. + Map shape_var_map; + for (const Var& var : params.value_or(Array())) { + const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; + } + shape_var_map.Set(shape_var, shape_expr); + } + } + scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); + } + + void EndScope() final { scope_stack_.pop_back(); } + + BindingBlock EndBlock() final { + BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlock ret = cur_frame->is_dataflow ? 
DataflowBlock(cur_frame->bindings) + : BindingBlock(cur_frame->bindings); + block_stack_.pop_back(); + return ret; + } + + bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } + + Var Emit(Expr expr, String name_hint) final { + return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); + } + + Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final { + value = this->Normalize(value); + + CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) + << "It is impossible to match cast any value into the target struct_info. " + "But got value struct info: " + << GetStructInfo(value) << ", given struct info: " << struct_info; + + // NOTE: do match cast checking later in a pass. + BlockFrame* cur_frame = CurrentBlockFrame(); + Var var = CreateVar(cur_frame->is_dataflow, name_hint); + UpdateStructInfo(var, struct_info); + + MatchCast match_cast(var, value, struct_info); + cur_frame->bindings.push_back(match_cast); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + return var; + } + + Var EmitOutput(Expr output, String name_hint) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + + return Emit(output, false, name_hint); + } + + void EmitNormalized(Binding binding) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + if (const auto* var_binding = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!var_binding->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(var_binding->var->struct_info_.defined()); + ICHECK(var_binding->value->struct_info_.defined()); + cur_frame->bindings.push_back(binding); + binding_table_[var_binding->var->vid] = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!match_cast->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(match_cast->var->struct_info_.defined()); + ICHECK(match_cast->value->struct_info_.defined()); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + cur_frame->bindings.push_back(binding); + } else { + LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); + } + } + + arith::Analyzer* GetAnalyzer() final { return &analyzer_; } + + protected: + /*! + * \brief A representation of a block frame. + * + * A block frame is a record containing the bindings needed + * to build a binding block, and a boolean to indicate if the + * block being built is a DataflowBlock or not. + */ + struct BlockFrame { + /*! + * \brief List of bindings + */ + Array bindings; + /*! \brief Whether current block is dataflow block. */ + bool is_dataflow; + /*! + * \brief Binding map used by normalizer. + * + * \note The normalizer only caches reuse in the current block scope + * and will not cache bindings from parent scope. + */ + std::unordered_map normalize_binding_map; + }; + /*! + * \brief A representation of a scope frame. + * + * A scope frame records tracks the context of current scope. + */ + struct ScopeFrame { + // NOTE: for simplicity, only tracks symbolic var for now + // the scope is only used for erasure, so less information means + // more conservative analysis. + // Consider impl alternative: merge with block frame if we have more frame kinds. 
+ // + // TODO(relax-team) tracks the var defined also through match-cast. + /*! \brief set of defined symbolic vars, value as themself. */ + Map shape_var_map; + }; + + /*! \brief A stack to store block frames. */ + std::vector block_stack_; + + /*! \brief A stack to store scope frames. */ + std::vector scope_stack_; + + /*! \brief A binding table that maps var to value. */ + std::unordered_map binding_table_; + + /*! \brief A name table to get unique names for IR construction. */ + std::unique_ptr name_table_ = std::make_unique(); + + /*! \brief The IRModule being built by the BlockBuilder. */ + IRModule context_mod_; + + /*! \brief Internal analzyer */ + arith::Analyzer analyzer_; + + /*! + * \return The current frame. + * \note Never hold the value of current frame between Normalize + * or other scope calls this value can change if the block stack get updated, + * then the block frame is no longer valid. + */ + BlockFrame* CurrentBlockFrame() { + ICHECK(!block_stack_.empty()) << "no block is being built"; + return &block_stack_.back(); + } + + /*! + * \return The current scope frame. + * \note only use this value + */ + ScopeFrame* CurrentScopeFrame() { + ICHECK(!scope_stack_.empty()) << "no scope is being opened"; + return &scope_stack_.back(); + } + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \note This Emit function normalizes the \p expr, + * and performs shape/type deductions by calling Normalize. + * \return The new variable that \p expr is bound to. + */ + Var Emit(Expr expr, bool is_dataflow, String name_hint) { + expr = this->Normalize(expr); + + Var var = CreateVar(is_dataflow, name_hint); + + // set the values + UpdateStructInfo(var, Downcast(expr->struct_info_.value())); + + CurrentBlockFrame()->bindings.push_back(VarBinding(var, expr)); + + // update the binding table + binding_table_[var->vid] = expr; + + return var; + } + + /*! + * \brief Create var for bindings + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \return The created var. + */ + Var CreateVar(bool is_dataflow, String name_hint) { + if (name_hint.empty()) { + name_hint = is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/NullOpt) + : Var(vid, /*struct_info_annotation=*/NullOpt); + } + + private: + /*! + * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs + * in context_mod to their GlobalVar to avoid generating duplicated functions. + */ + std::unique_ptr> + ctx_func_dedup_map_ = nullptr; + + /*! + * \brief lazily initialize function dedeup map. + */ + void LazyInitCtxFuncDedupMap() { + if (ctx_func_dedup_map_ != nullptr) return; + ctx_func_dedup_map_ = std::make_unique< + std::unordered_map>(); + for (const auto& kv : context_mod_->functions) { + const GlobalVar gv = kv.first; + const BaseFunc func = kv.second; + ctx_func_dedup_map_->emplace(func, gv); + } + } + + // Collect all the variables that a parameter var can define. 
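+ /*
+  * [Illustrative sketch, not part of this patch] CreateVar above falls back to the "lv"/"gv"
+  * prefixes and delegates uniqueness to the name table. A toy counter-suffix table is sketched
+  * below; the exact suffix format used by NameTable is an implementation detail and may differ.
+  *
+  * \code
+  * #include <string>
+  * #include <unordered_map>
+  *
+  * class ToyNameTable {
+  *  public:
+  *   std::string GetUniqueName(std::string hint) {
+  *     if (hint.empty()) hint = "v";
+  *     int& count = counts_[hint];
+  *     // The first request returns the hint itself; later requests get a numeric suffix.
+  *     return count++ == 0 ? hint : hint + "_" + std::to_string(count - 1);
+  *   }
+  *
+  *  private:
+  *   std::unordered_map<std::string, int> counts_;
+  * };
+  *
+  * // ToyNameTable t; t.GetUniqueName("lv") -> "lv"; t.GetUniqueName("lv") -> "lv_1";
+  * \endcode
+  */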
+ // The collector is used to making sure that we record the + // shape vars as defined when calling BeginScope(params) + class StructInfoVarCollector : public StructInfoVisitor { + public: + static Map Collect(const StructInfo& struct_info) { + StructInfoVarCollector collector; + collector(struct_info); + return collector.shape_var_map_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) final { + if (const auto* shape_expr = op->shape.as()) { + for (const PrimExpr& s : shape_expr->values) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + } + + void VisitStructInfo_(const ShapeStructInfoNode* op) final { + for (const PrimExpr& s : op->values.value_or(Array())) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + + private: + Map shape_var_map_; + }; +}; + +//--------------------------------------- +// Normalization +//--------------------------------------- +#define RELAX_EXPR_NORMALIZER_LEAF(OP) \ + Expr VisitExpr_(const OP* op) final { return GetRef(op); } + +// TODO(relax-team): Check normalize logic after struct info. + +// Normalizer on struct info: +// +// We take benefit of the following invariants(that are checked in constructor): +// - If an expr appears in StructInfo, then it is already normalized. +// As a result, we do not need to peek into StructInfo in Normalization. +// - Constant, ShapeExpr, already have their StructInfo populated in constructing time. +class Normalizer : public BlockBuilderImpl, private ExprFunctor { + public: + explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} + + Expr Normalize(const Expr& expr) final { + Expr normalized = this->VisitExpr(expr); + // Invariant: + // After Normalize: an Expr always have + // struct_info (with the exception of Op). + if (!normalized->IsInstance()) { + ICHECK(normalized->struct_info_.defined()) + << "The struct_info_ of an Expr except OpNode after " + "normalization must not be nullptr. However, this Expr does not have struct_info_: " + << normalized; + } + + return normalized; + } + + /*! + * \brief Normalize Argument values to call and other IR sub-fields. + * \param arg The argument. + * \return The normalized value. + * + * \note This function create a new binding for non-leaf expressions except for tuple. 
+ */ + Expr NormalizeArgument(const Expr& arg) final { + if (!block_stack_.empty()) { + // cache lookup + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(arg); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + // skip visit expr's cache, normalize arg + Expr post = ExprFunctor::VisitExpr(arg); + + if (!IsLeafOrTuple(arg)) { + ICHECK(!block_stack_.empty()) << "Cannot normalize non-leaf without a scope"; + Var var = this->Emit(post, ""); + // NOTE: current frame addr can change due to underlying vector + // re-allocation, redo lookup + CurrentBlockFrame()->normalize_binding_map[arg] = var; + return var; + } else { + return post; + } + } + + RELAX_EXPR_NORMALIZER_LEAF(ExternFuncNode); + RELAX_EXPR_NORMALIZER_LEAF(GlobalVarNode); + RELAX_EXPR_NORMALIZER_LEAF(OpNode); + RELAX_EXPR_NORMALIZER_LEAF(ConstantNode); + RELAX_EXPR_NORMALIZER_LEAF(ShapeExprNode); + RELAX_EXPR_NORMALIZER_LEAF(PrimValueNode); + RELAX_EXPR_NORMALIZER_LEAF(StringImmNode); + RELAX_EXPR_NORMALIZER_LEAF(DataTypeImmNode); + + template + Expr VisitVar_(const typename T::ContainerType* var) { + // Parameters and free-vars must be present with struct info + // Other vars must have already been normalized through binding + ICHECK(var->struct_info_.defined()) + << "Var " << var->name_hint() << " does not have struct info."; + return GetRef(var); + } + + Expr VisitExpr_(const VarNode* var) final { return VisitVar_(var); } + + Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_(var); } + + Expr VisitExpr(const Expr& expr) final { + // lookup normalize map + if (!block_stack_.empty()) { + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(expr); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + return ExprFunctor::VisitExpr(expr); + } + + Expr VisitExpr_(const TupleNode* op) final { + bool unchanged = true; + Array new_fields; + + for (const Expr& field : op->fields) { + Expr new_field = this->NormalizeArgument(field); + new_fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields, op->span); + // Update tuple fields. 
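+ /*
+  * [Illustrative sketch, not part of this patch] NormalizeArgument above binds every non-leaf
+  * argument to a fresh variable and caches the result in the current block frame, so the same
+  * nested expression appearing twice in one block reuses a single binding. A toy version of
+  * that cache-or-emit step (string keys stand in for the Expr keys of normalize_binding_map):
+  *
+  * \code
+  * #include <string>
+  * #include <unordered_map>
+  * #include <utility>
+  * #include <vector>
+  *
+  * struct ToyBlock {
+  *   std::vector<std::pair<std::string, std::string>> bindings;  // (var, expr)
+  *   std::unordered_map<std::string, std::string> cache;         // expr -> var
+  *   int counter = 0;
+  *
+  *   std::string NormalizeArgument(const std::string& expr, bool is_leaf) {
+  *     if (is_leaf) return expr;                  // leaves (vars, constants) pass through
+  *     auto it = cache.find(expr);
+  *     if (it != cache.end()) return it->second;  // reuse the earlier binding
+  *     std::string var = "lv" + std::to_string(counter++);
+  *     bindings.emplace_back(var, expr);          // emit: lv = expr
+  *     cache.emplace(expr, var);
+  *     return var;
+  *   }
+  * };
+  * \endcode
+  */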
+ if (!tuple->struct_info_.defined()) { + Array tuple_sinfo; + for (Expr field : tuple->fields) { + tuple_sinfo.push_back(GetStructInfo(field)); + } + UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo, op->span)); + } + return tuple; + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr new_body = this->VisitWithNewScope(op->body, op->params); + + if (new_body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, new_body, op->ret_struct_info, op->attrs); + } + } + + Expr VisitExpr_(const CallNode* op) final { + Expr new_op = this->NormalizeArgument(op->op); + bool unchanged = new_op.same_as(op->op); + + Array new_args; + + for (Expr arg : op->args) { + Expr new_arg = this->NormalizeArgument(arg); + new_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + Call call; + if (unchanged) { + call = GetRef(op); + } else { + call = Call(new_op, new_args, op->attrs, op->sinfo_args); + } + + if (!call->struct_info_.defined()) { + auto inferred_sinfo = InferStructInfo(call); + UpdateStructInfo(call, inferred_sinfo); + } + + return call; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool unchanged = true; + Array new_blocks; + for (BindingBlock block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + new_blocks.push_back(new_block); + unchanged &= new_block.same_as(block); + } + + this->BeginBindingBlock(); + // the body may not be a leaf expression, so check for that + Expr new_body = this->NormalizeArgument(op->body); + unchanged &= new_body.same_as(op->body); + BindingBlock prologue = this->EndBlock(); + + if (!prologue->bindings.empty()) { + new_blocks.push_back(prologue); + unchanged = false; + } + + // Combine nearby blocks if possible + Array normalized_blocks = NormalizeBlocks(new_blocks); + unchanged &= normalized_blocks.same_as(new_blocks); + + SeqExpr seq_expr; + if (unchanged) { + seq_expr = GetRef(op); + } else { + seq_expr = SeqExpr(normalized_blocks, new_body, op->span); + } + + // only do shape/type inference if the SeqExpr does not have shape/type + if (!seq_expr->struct_info_.defined()) { + UpdateStructInfo(seq_expr, EraseToWellDefinedInScope(GetStructInfo(seq_expr->body))); + } + return seq_expr; + } + + Expr VisitExpr_(const IfNode* op) final { + Expr new_cond = this->NormalizeArgument(op->cond); + Expr new_true = this->VisitWithNewScope(op->true_branch); + Expr new_false = this->VisitWithNewScope(op->false_branch); + + If if_node; + if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) && + new_false.same_as(op->false_branch)) { + if_node = GetRef(op); + } else { + if_node = If(new_cond, new_true, new_false, op->span); + } + if (!if_node->struct_info_.defined()) { + auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true)); + auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false)); + UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info)); + } + return if_node; + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_tuple = this->NormalizeArgument(op->tuple); + + TupleGetItem node = new_tuple.same_as(op->tuple) ? 
GetRef(op) + : TupleGetItem(new_tuple, op->index); + + if (!node->struct_info_.defined()) { + auto opt = MatchStructInfo(node->tuple); + ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo."; + UpdateStructInfo(node, opt.value()->fields[node->index]); + } + + return node; + } + + Binding VisitBinding(const Binding& binding) { + if (auto* var_binding = binding.as()) { + return this->VisitVarBinding(GetRef(var_binding)); + } else { + auto* match_cast = binding.as(); + ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); + return this->VisitMatchCast(GetRef(match_cast)); + } + } + + VarBinding VisitVarBinding(VarBinding binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = VarBinding(binding->var, new_value, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + return binding; + } + + MatchCast VisitMatchCast(MatchCast binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = MatchCast(binding->var, new_value, binding->struct_info, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, binding->struct_info); + } + return binding; + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + this->BeginDataflowBlock(); + } else { + this->BeginBindingBlock(); + } + + bool unchanged = true; + for (const Binding& binding : block->bindings) { + Binding new_binding = this->VisitBinding(binding); + unchanged &= new_binding.same_as(binding); + + this->EmitNormalized(new_binding); + } + BindingBlock new_block = this->EndBlock(); + unchanged &= new_block->bindings.size() == block->bindings.size(); + if (unchanged) { + return block; + } + return new_block; + } + + private: + // Helper function to infer the type of a Call. + StructInfo InferStructInfo(const Call& call) { + if (auto* op_ptr = call->op.as()) { + // Case 1: the op field is a primitive op, look up FInferStructInfo attribute + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + return op_map_infer_struct_info_[op](call, GetRef(this)); + } else { + // derive using function parameters + ICHECK(call->op->struct_info_.defined()); + auto opt = MatchStructInfo(call->op); + ICHECK(opt) << "Call->op must contains a function struct info"; + FuncStructInfo finfo = opt.value(); + return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); + } + } + + // erase to well defined within current scope. + StructInfo EraseToWellDefinedInScope(StructInfo info) { + if (scope_stack_.empty()) { + return EraseToWellDefined(info); + } + auto* curr_scope = CurrentScopeFrame(); + auto f_shape_var_map = [curr_scope](tir::Var var) -> Optional { + auto it = curr_scope->shape_var_map.find(var); + if (it != curr_scope->shape_var_map.end()) return (*it).second; + return NullOpt; + }; + return EraseToWellDefined(info, f_shape_var_map); + } + + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + // SeqExpr do not need to prepare for normalization. 
+ if (expr.as()) { + this->BeginScope(params); + Expr ret = this->VisitExpr(expr); + this->EndScope(); + return ret; + } else { + this->BeginScope(params); + + this->BeginBindingBlock(); + Expr post = this->NormalizeArgument(expr); + BindingBlock prologue = this->EndBlock(); + // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. + // Don't wrap if it's already a seq and there are no bindings to add + if (post.as() && prologue->bindings.empty()) { + return post; + } + Array bindings; + if (!prologue->bindings.empty()) { + bindings.push_back(prologue); + } + + SeqExpr seq(bindings, post); + UpdateStructInfo(seq, EraseToWellDefinedInScope(GetStructInfo(seq->body))); + + this->EndScope(); + return seq; + } + } + + Array FlattenBlocks(const Array& blocks) { + // If there is a binding that is a seq expr, split the current block, + // add the nested blocks prior to the seq expr, and bind the seq expr body + // to the var + Array ret; + bool changed = false; + for (const BindingBlock& block : blocks) { + bool is_dataflow = block->IsInstance(); + Array current; + for (const Binding& binding : block->bindings) { + Expr value; + if (const auto* var_binding = binding.as()) { + value = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + value = match_cast->value; + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + // if we encounter a nested seq, we have to flatten it: + // 1. Append the binding block we've accumulated so far + // 2. Reset the current block + // 3. Append the inner blocks + // 4. Add a binding of the current var to the seq expr's body to the current block + // then continue + if (auto seq = value.as()) { + changed = true; + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + current = {}; + // We do not need to flatten recursively because the normalizer will have normalized + // and thus flattened the inner SeqExprs already + for (const BindingBlock& block : seq->blocks) { + if (is_dataflow && !block->IsInstance()) { + LOG(WARNING) << "Malformed AST: Seq expr nested inside a dataflow block contains a " + "non-dataflow block! " + << seq; + } + ret.push_back(block); + } + + if (const auto* var_binding = binding.as()) { + current.push_back(VarBinding(var_binding->var, seq->body)); + } else if (const auto* match_cast = binding.as()) { + current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + } else { + current.push_back(binding); + } + } + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + } + return changed ? ret : blocks; + } + + Array NormalizeBlocks(const Array& blocks) { + bool changed = false; + Array ret; + auto flattened = FlattenBlocks(blocks); + if (!flattened.same_as(blocks)) { + changed = true; + } + for (const BindingBlock& block : flattened) { + if (block->bindings.empty()) { + // Case 1. Skip empty blocks + changed = true; + } else if (!ret.empty() && ret.back()->type_index() == block->type_index()) { + // Case 2. Merge with previous block if possible + BindingBlock merged; + // NOTE: should check DataflowBlockNode first. 
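+ /*
+  * [Illustrative sketch, not part of this patch] FlattenBlocks above splices a binding whose
+  * right-hand side is itself a sequence: the inner blocks are hoisted in front and the original
+  * variable is re-bound to the inner body. A toy flattening of the same shape, ignoring the
+  * dataflow/non-dataflow distinction (the struct and field names are invented):
+  *
+  * \code
+  * #include <string>
+  * #include <vector>
+  *
+  * struct ToyBinding {
+  *   std::string var;
+  *   std::vector<ToyBinding> inner;  // non-empty when the bound value is a nested sequence
+  *   std::string body;               // the value, or the nested sequence's body
+  * };
+  *
+  * std::vector<ToyBinding> Flatten(const std::vector<ToyBinding>& bindings) {
+  *   std::vector<ToyBinding> out;
+  *   for (const ToyBinding& b : bindings) {
+  *     // Hoist the inner bindings first. One level suffices: inner sequences were already
+  *     // normalized, mirroring the comment in FlattenBlocks.
+  *     for (const ToyBinding& inner : b.inner) out.push_back(inner);
+  *     out.push_back(ToyBinding{b.var, {}, b.body});  // re-bind var to the inner body
+  *   }
+  *   return out;
+  * }
+  * \endcode
+  */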
+ if (const auto* dataflow_block = ret.back().as()) { + auto n = make_object(*dataflow_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = DataflowBlock(n); + } else if (const auto* binding_block = ret.back().as()) { + auto n = make_object(*binding_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = BindingBlock(n); + } else { + LOG(FATAL) << "Unknown block type: " << ret.back()->GetTypeKey(); + } + ret.pop_back(); + ret.push_back(merged); + changed = true; + } else { + // Case 3. Add to the result + ret.push_back(block); + } + } + return changed ? ret : blocks; + } + + /*! \brief Operator struct info inference map. */ + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); +}; + +BlockBuilder BlockBuilder::Create(Optional mod) { + ObjectPtr n = make_object(mod.value_or(IRModule())); + return BlockBuilder(n); +} + +//--------------------------------------- +// User facing function registration. +//--------------------------------------- +TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { + return BlockBuilder::Create(mod); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") + .set_body_method(&BlockBuilderNode::BeginDataflowBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") + .set_body_method(&BlockBuilderNode::BeginBindingBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock") + .set_body_method(&BlockBuilderNode::EndBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit") + .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) { + return builder->Emit(expr, name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") + .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) { + return builder->EmitMatchCast(value, struct_info); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") + .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) { + return builder->EmitOutput(output, name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") + .set_body_typed([](BlockBuilder builder, Binding binding) { + return builder->EmitNormalized(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") + .set_body_typed([](BlockBuilder builder, String name_hint) { + return builder->name_table()->GetUniqueName(name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") + .set_body_method(&BlockBuilderNode::AddFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") + .set_body_method(&BlockBuilderNode::UpdateFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") + .set_body_method(&BlockBuilderNode::GetContextIRModule); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") + .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") + .set_body_method(&BlockBuilderNode::LookupBinding); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope") + .set_body_method(&BlockBuilderNode::EndScope); +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc new file 
mode 100644 index 000000000000..c1306ff69093 --- /dev/null +++ b/src/relax/ir/dataflow_matcher.cc @@ -0,0 +1,929 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relax. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dataflow_matcher_impl.h" + +namespace tvm { +namespace relax { + +using tvm::arith::Analyzer; + +// Pattern Matcher +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + matched_nodes_.clear(); + return VisitDFPattern(pattern, expr); +} + +static Expr TryGetValOfVar(const Expr& expr, const Map& var2val) { + if (var2val.empty()) return expr; + + // if not match, try to match value of var if expr is a var. + if (const VarNode* var = expr.as()) { + auto may = var2val.Get(GetRef(var)); + if (may.defined()) return may.value(); + } + + return expr; +} + +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (memoize_ && memo_.count(pattern)) { + ICHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); + } else { + size_t watermark = matched_nodes_.size(); + bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern].push_back(expr); + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return !VisitDFPattern(op->reject, expr); +} + +bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + switch (rhs.type_code()) { + case kDLInt: + if (auto* val = lhs.as()) { + return val->value == rhs.operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = lhs.as()) { + return val->value == rhs.operator double(); + } + break; + case kTVMStr: + if (auto* val = lhs.as()) { + return 
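+ /*
+  * [Illustrative sketch, not part of this patch] VisitDFPattern above records every successful
+  * sub-match in matched_nodes_ and, on failure, ClearMap rolls the state back to a watermark.
+  * That record/rollback discipline is what makes alternatives such as OrPattern safe to try in
+  * sequence. A standalone model on a toy match log:
+  *
+  * \code
+  * #include <string>
+  * #include <vector>
+  *
+  * struct ToyMatcher {
+  *   std::vector<std::string> matched;  // log of sub-patterns matched so far
+  *
+  *   bool TryAlternatives(const std::vector<std::vector<std::string>>& alternatives) {
+  *     for (const auto& alt : alternatives) {
+  *       auto watermark = matched.size();  // remember the state before this branch
+  *       bool ok = true;
+  *       for (const auto& step : alt) {
+  *         if (step == "fail") { ok = false; break; }
+  *         matched.push_back(step);        // record a successful sub-match
+  *       }
+  *       if (ok) return true;              // keep what this branch matched
+  *       matched.resize(watermark);        // ClearMap: undo the failed branch
+  *     }
+  *     return false;
+  *   }
+  * };
+  * \endcode
+  */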
val->value == rhs.operator std::string(); + } else if (auto* val = lhs.as()) { + return val->data == rhs.operator std::string(); + } + break; + case kTVMDataType: + if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->data; + } else { + ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs; + } + break; + case kTVMObjectHandle: + if (rhs.IsObjectRef()) { + if (auto* val = lhs.as()) { + return rhs.operator String() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator String() == val->data; + } + } else { + // Compare the objects for structural equality + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + return true; + } + } + break; + default: + ICHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = VisitDFPattern(attr_pattern->pattern, expr); + if (!matches) return matches; + VLOG(1) << "considering AttrPatternNode at:\n" << expr; + auto attributes = attr_pattern->attrs.as()->dict; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + if (Op::HasAttrMap(attr_name)) { + auto op_map = Op::GetAttrMap(attr_name); + if (op_map.count(op)) { + matches &= MatchRetValue(attr_value, op_map[op]); + } else { + matches = false; + } + } else { + matches = false; + } + } + } else if (auto* op = expr.as()) { + matches = true; + // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // and replace the whole thing with a Visitor-based approach + ReflectionVTable* reflection = ReflectionVTable::Global(); + auto attrs_node = const_cast(op->attrs.get()); + // attrs may be undefined on non-op calls so we check first + std::vector attr_names; + if (attrs_node) { + attr_names = reflection->ListAttrNames(attrs_node); + } + for (auto kv : attributes) { + std::string attr = kv.first; + if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) { + matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr)); + } else { + matches = false; + break; + } + } + } else if (auto* op = expr.as()) { + matches = true; + for (auto kv : attributes) { + if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { + matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + } else { + matches = false; + break; + } + } + } else { + matches = false; + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + // utilities + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* 
call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + + // logic + auto watermark = matched_nodes_.size(); + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + auto watermark2 = matched_nodes_.size(); + + auto match_args = [this, &watermark2](const Array& pattern_args, auto expr_begin, + auto expr_end) { + bool matches = true; + auto pattern_it = pattern_args.begin(); + auto expr_it = expr_begin; + if (pattern_args.defined()) { + while (matches && pattern_it != pattern_args.end()) + matches &= VisitDFPattern(*(pattern_it++), *(expr_it++)); + } + if (!matches) ClearMap(watermark2); + return matches; + }; + + const size_t n_arg_pattern = op->args.size(); + const size_t n_arg_expr = call_node->args.size(); + // if allow variable args, #pattern must >= #expr. + if (op->varg_default_wildcard && n_arg_expr < n_arg_pattern) return false; + // if variable args are not allowed, #pattern must == #expr. + if (!op->varg_default_wildcard && n_arg_expr != n_arg_pattern) return false; + + // Standard case + if (match_args(op->args, call_node->args.begin(), call_node->args.end())) return true; + + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { + if ((op_node->name == "relax.add") || (op_node->name == "relax.multiply")) { + if (match_args(op->args, call_node->args.rbegin(), call_node->args.rend())) { + return true; + } + } + } + } else { + ClearMap(watermark); + // associate divide/multiply + if (is_pattern_op(op, "relax.divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "relax.multiply") && is_expr_op(expr, "relax.multiply") && + (is_expr_op(call_node->args[0], "relax.divide") || + is_expr_op(call_node->args[1], "relax.divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); + } + } + return out; + } + } + } + if (is_pattern_op(op, "relax.multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "relax.divide") && is_expr_op(expr, "relax.divide") && + (is_expr_op(call_node->args[0], "relax.multiply") || + is_expr_op(call_node->args[1], "relax.multiply"))) { + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}); + return VisitDFPattern(div, expr); + } + } + } + } + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return StructuralEqual()(op->expr, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* func = expr.as()) { + matches = true; + if (op->params.defined()) { + size_t i = 0; + if (op->params.size() == func->params.size()) { + while (matches && i < op->params.size()) { + matches &= VisitDFPattern(op->params[i], func->params[i]); + ++i; + } + } else { + matches = false; + } + } + if (matches) { + matches &= VisitDFPattern(op->body, 
func->body); + } + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* tuple_get_item_node = expr.as()) { + return (op->index == -1 || op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* tuple_node = expr.as()) { + matches = true; + if (op->fields.size() == tuple_node->fields.size()) { + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } else { + matches = false; + } + } + return matches; +} + +bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array patterns, + const tvm::Array fields, + std::vector& match_cache, + std::vector& matched) { + if (idx >= patterns.size()) return true; + constexpr int8_t kUnknown = -1; + auto this_pattern = patterns[idx]; + for (size_t i = 0; i < fields.size(); ++i) { + if (matched[i]) continue; + const size_t table_idx = idx * fields.size() + i; + match_cache[table_idx] = + kUnknown ? VisitDFPattern(this_pattern, fields[i]) : match_cache[table_idx]; + if (match_cache[table_idx]) { + // continue to match the rest; + matched[i] = true; + if (TryUnorderedMatch(idx + 1, patterns, fields, match_cache, matched)) return true; + matched[i] = false; + } + } + + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + constexpr int8_t kUnknown = -1; + ICHECK_LE(op->fields.size(), std::numeric_limits::max()) << "Too many fields!"; + // dynamic programming. 
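For illustration, a minimal sketch (assuming the IsTuple, IsConst and Wildcard helpers defined later in dataflow_pattern.cc) of the ordered versus unordered tuple patterns that this DP table serves:

// Ordered tuple patterns fix field positions; the unordered variant tries
// permutations of the fields, memoized by the match_cache table above.
DFPattern ordered   = IsTuple({IsConst(), Wildcard()}, /*unordered=*/false);  // matches (const, x) only
DFPattern unordered = IsTuple({IsConst(), Wildcard()}, /*unordered=*/true);   // also matches (x, const)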
+ std::vector match_cache(op->fields.size() * op->fields.size(), kUnknown); + std::vector field_match_bitmap(op->fields.size(), false); + return TryUnorderedMatch(0, op->fields, tuple_node->fields, match_cache, field_match_bitmap); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + auto expr_type = expr.as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); +} + +static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) + if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; + return true; +} + +bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { + // no need to jump, as var.shape == value.shape + if (const auto* tinfo = GetStructInfoAs(expr)) { + if (const ShapeExprNode* shape_expr = tinfo->shape.as()) { + return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && + VisitDFPattern(op->pattern, expr); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const ShapeExprNode* shape_expr = expr.as()) + return ShapeEqual(&analyzer_, op->fields, shape_expr->values); + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { + // no need to jump, as var.dtype == value.dtype + auto expr_type = expr.as()->checked_type(); + if (const DynTensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + // We don't jump for var pattern, as there's no need to access its value to judge it. + if (const auto* var_node = expr.as()) { + // "" means any name. + return "" == op->name_hint() || op->name_hint() == var_node->name_hint(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* extern_fn = expr.as()) { + return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { + // constants can be binded to relax.Var as well. + auto expr = TryGetValOfVar(expr0, var2val_); + return expr.as() != nullptr; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& expr) { + // DataflowVar is inherented from Var, so dispatch it to VarPattern. + return expr->IsInstance() && + VisitDFPattern_(static_cast(op), expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) { + // GlobalVarPattern is not inherited from Var, so we need to handle it separately. 
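For illustration, a minimal sketch of how the shape, dtype and var checks above are usually reached through the sugar methods defined in dataflow_pattern.cc; the concrete dtype, shape and name values here are made up:

// A wildcard refined by dtype and symbolic-shape predicates.
DFPattern fp32_matrix = Wildcard().HasDtype("float32").HasShape({PrimExpr(1), PrimExpr(224)});
// A named variable pattern; an empty name would match any relax::Var.
DFPattern named_param = IsVar("weight");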
+ if (const auto* var_node = expr.as()) + return "" == op->name_hint() || op->name_hint() == var_node->name_hint; + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { + return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +struct PNode { + const DFPatternNode* ptr; + const VarNode* matched = nullptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + const DFPatternNode* matched = nullptr; + std::vector children; + std::vector parents; +}; + +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +using UndoItems = std::vector>; +static std::optional try_match( + PNode* p, RNode* r, DFPatternMatcher* m, + const std::map>& def2use, + const std::map>& use2def) { + if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. + if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; + + UndoItems undo; + + const auto commit = [&undo](PNode* p, RNode* r) { + // match with each other. + // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? + if (p->ptr == r->matched) { + ICHECK_EQ(p->matched, r->ptr); + return; + } + p->matched = r->ptr; + r->matched = p->ptr; + undo.emplace_back(p, r); + }; + + const auto quit = [&undo] { + for (auto& [p_node, r_node] : undo) { + p_node->matched = nullptr; + r_node->matched = nullptr; + } + return std::nullopt; + }; + + const auto try_match_update_undo = [&](PNode* p, RNode* r) { + if (auto undo_more = try_match(p, r, m, def2use, use2def)) { + undo.insert(undo.end(), undo_more->begin(), undo_more->end()); + return true; + } + return false; + }; + + commit(p, r); + + // match parent patterns. + for (auto& [pparent, constraints] : p->parents) { + bool any_cons_sat = false; + for (auto& rparent : r->parents) { + // skip if mismatch. + if (rparent->matched && rparent->matched != pparent->ptr) continue; + + const auto& uses = def2use.at(rparent->ptr); + + // check edge constraints. + bool cons_sat = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + cons_sat = false; + break; + } + + if (cons.index != -1) { + const auto& callees = use2def.at(r->ptr); + if (callees.size() <= static_cast(cons.index) || + callees[cons.index] != rparent->ptr) { + cons_sat = false; + break; + } + } + } + if (!cons_sat) continue; + any_cons_sat = true; + + // try all parent R nodes that are not matched yet. + // as long as ppattern can match one node. 
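For illustration, a minimal sketch of driving the expression-level entry points above from C++; the helper name and the expression argument e are hypothetical:

// Returns true when `e` is a call to relax.add with any two arguments.
bool IsAddOfAnything(const Expr& e) {
  DFPattern pat = IsOp("relax.add")({Wildcard(), Wildcard()});
  return MatchExpr(pat, e, NullOpt);
}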
+ if (!pparent->matched && try_match_update_undo(pparent, rparent)) { + commit(pparent, rparent); + break; + } + } + if (!pparent->matched || !any_cons_sat) return quit(); + } + + // forward matching; + for (auto& [pchild, constraints] : p->children) { + bool any_cons_sat = false; + for (auto& rchild : r->children) { + if (rchild->matched && rchild->matched != pchild->ptr) continue; + + const auto& uses = def2use.at(r->ptr); + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (cons.index != -1) { + const auto& callees = use2def.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r->ptr) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass) continue; + any_cons_sat = true; + + if (!pchild->matched && try_match_update_undo(pchild, rchild)) { + commit(pchild, rchild); + break; + } + } + if (!pchild->matched || !any_cons_sat) return quit(); + } + return undo; +} + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + Optional start_hint, bool must_include_hint) { + if (ctx->src_ordered.size() == 0) { + return NullOpt; + } + + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + ICHECK(!must_include_hint || start_hint.defined()) + << "must_include_hint is only supported with start_hint."; + + const auto var2val = AnalyzeVar2Value(dfb); + DFPatternMatcher matcher(var2val); + + MatcherUseDefAnalysis ud_analysis; + ud_analysis.VisitBindingBlock_(dfb.get()); + const auto& def2use = ud_analysis.def2use; + const auto& caller2callees = ud_analysis.caller2callees; + + // First construct a graph of PNode and RNode. 
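For illustration, a minimal sketch of invoking the graph-level match below; dfb stands in for a real relax::DataflowBlock, and it is assumed here that PatternContext's boolean constructor and add_constraint are callable from this translation unit:

// "multiply is only used by add", expressed as a two-node constraint graph.
DFPattern producer = IsOp("relax.multiply")({Wildcard(), Wildcard()});
DFPattern consumer = IsOp("relax.add")({producer, Wildcard()});
PatternContext ctx(/*incremental=*/false);
ctx.add_constraint(producer, consumer, PairCons(PairCons::kOnlyUsedBy));
if (auto assignment = MatchGraph(ctx, dfb, NullOpt, /*must_include_hint=*/false)) {
  // assignment.value() maps each pattern to the bound relax::Var it matched.
}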
+ std::unordered_map var2node; + var2node.reserve(dfb->bindings.size()); + + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = def2use.at(cur_var); + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->constraints.size()); + + for (const auto& [def_pattern, uses] : ctx->constraints) { + PNode& def_node = pattern2node[def_pattern.get()]; + def_node.ptr = def_pattern.get(); + def_node.children.reserve(uses.size()); + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + Map ret; + + if (start_hint) { + auto rnode_ptr = var2node.at(start_hint.value().get()); + for (auto& p_node : pattern2node) { + if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { + for (const auto& [df_pattern, pattern_node] : pattern2node) { + ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); + } + return ret; + } + } + + if (must_include_hint) return ret; + } + + PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()]; + + if (!pnode_start.matched) { + for (const auto& var : ud_analysis.vars) { + if (start_hint.defined() && start_hint.value().get() == var) continue; + RNode& r_node = var2node[var]; + if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { + for (const auto& [df_pattern, pattern_node] : pattern2node) { + ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); + } + return ret; + } + } + } + + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); + +/*! + * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones + * with the output of a user-provided rewriter function. + */ +class PatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + PatternRewriter(DFPattern pat, PackedFunc rewriter_func, + const std::unordered_set& params) + : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {} + + PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func, + const std::unordered_set& params) + : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} + + template + static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) { + std::unordered_set params; + for (const auto& p : f->params) { + params.insert(p.get()); + } + PatternRewriter rewriter(pat, rewriter_func, params); + return RemoveAllUnused(Downcast(rewriter.VisitExpr(f))); + } + + void VisitBinding_(const VarBindingNode* binding) final { + bindings_.Set(binding->var, binding->value); + ExprMutator::VisitBinding_(binding); + if (auto it = memo_.find(binding->value.get()); it != memo_.end()) { + // We need to update the binding to pass to ExtractMatchedExpr, so that the rewritten + // expression can be subject to further pattern matchings. 
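For illustration, a minimal sketch of exercising the call-rewriting entry point registered further below via the packed-function registry; the wrapper function and the identity rewriter are hypothetical:

// Rewrites every relax.add call in `f` (here a no-op rewrite, for demonstration only).
Function RewriteAddCalls(Function f) {
  const runtime::PackedFunc* rewrite_call = runtime::Registry::Get("relax.dpl.rewrite_call");
  ICHECK(rewrite_call != nullptr);
  runtime::TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter(
      [](Expr matched, Map<DFPattern, Expr> matches) -> Expr {
        return matched;  // a real rewriter would construct and return a replacement call
      });
  return (*rewrite_call)(IsOp("relax.add")({Wildcard(), Wildcard()}), rewriter, f);
}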
+ bindings_.Set(binding->var, it->second); + } + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = ExprMutator::VisitExpr_(call_node); + if (!pattern_) { + return call; + } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) { + auto rewriten_expr = rewriter_func_(call, matches_opt.value()); + memo_[call_node] = rewriten_expr; + return rewriten_expr; + } + return call; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { + if (!ctx_) { + return ExprMutator::VisitBindingBlock_(block_node); + } + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + } + + private: + void EmitUsedVars(Expr val, const Array& pending_bindings, + std::unordered_set* emitted_vars) { + std::unordered_set unemitted_vars; + PostOrderVisit(val, [=, &unemitted_vars](Expr e) { + if (auto v = e.as(); v && !emitted_vars->count(v)) { + unemitted_vars.insert(v); + } + }); + + if (unemitted_vars.empty()) { + return; + } + + size_t num_unemitted = unemitted_vars.size(); + for (size_t i = 0; i < pending_bindings.size(); ++i) { + const auto& binding = pending_bindings[i]; + if (auto var_bind = binding.as(); + var_bind && unemitted_vars.count(var_bind->var.get())) { + // var_bind->value may also depend on other unemitted vars in this range + Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); + EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); + this->VisitBinding(binding); + emitted_vars->insert(var_bind->var.get()); + if (--num_unemitted == 0) { + return; + } + } + } + } + + // Repeat until all matchable subsets of bindings are rewritten. + BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { + if (auto matches = MatchGraph(ctx_.value(), Downcast(block))) { + builder_->BeginDataflowBlock(); + Map replacements = rewriter_func_(matches.value()); + + std::unordered_set emitted_vars; + + for (size_t i = 0; i < block->bindings.size(); ++i) { + const auto& binding = block->bindings[i]; + if (auto var_bind = binding.as()) { + if (replacements.count(var_bind->var)) { + auto new_val = replacements[var_bind->var]; + Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); + // Make sure there is no unbound variable used in the new value before it is emitted + EmitUsedVars(new_val, pending_bindings, &emitted_vars); + this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); + } else if (!emitted_vars.count(var_bind->var.get())) { + this->VisitBinding(binding); + emitted_vars.insert(var_bind->var.get()); + } + } else { + this->VisitBinding(binding); + } + } + return RewriteDataflowBlockFixedPoint(builder_->EndBlock()); + } + return block; + } + + /*! \brief The pattern for rewriting call nodes */ + Optional pattern_; + /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ + Optional ctx_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * - (Call, Map) -> Call for call node rewriting. Given the matched + * call node and the map of patterns and matched expressions, it should return a new call node + * to replace the original one or the original matched call node as is. + * - Map -> Map for dataflow block rewriting. Given the map of patterns + * and corresponding variables (bound variables or parameters), it should return a map that + * specifies new values for matched bound variables. 
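For illustration, a minimal sketch of a rewriter satisfying the dataflow-block contract documented above (a Map from DFPattern to Var in, a Map from Var to Expr out); the body shown is a deliberately empty placeholder:

runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> block_rewriter(
    [](Map<DFPattern, Var> matches) {
      Map<Var, Expr> replacements;
      // For every matched bound variable, set the new value it should be rebound to.
      return replacements;
    });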
+ */ + PackedFunc rewriter_func_; + std::unordered_set params_; + Map bindings_; + std::unordered_map memo_; +}; + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call") + .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) { + return PatternRewriter::Run(pat, rewriter, f); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings") + .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) { + return PatternRewriter::Run(ctx, rewriter, f); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher_impl.h new file mode 100644 index 000000000000..89f3d114c1e3 --- /dev/null +++ b/src/relax/ir/dataflow_matcher_impl.h @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relax/dataflow_matcher_impl.h + * \brief The auxiliary data structure for dataflow matcher. + */ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +class DFPatternMatcher : public DFPatternFunctor { + public: + using var2val_t = runtime::Map; + + explicit DFPatternMatcher() {} + explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const OrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AndPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const NotPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + + bool VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& 
expr) override; + bool VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool TryUnorderedMatch(size_t idx, const tvm::Array patterns, + const tvm::Array fields, std::vector& match_cache, + std::vector& matched); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + var2val_t var2val_; + std::vector matched_nodes_; + arith::Analyzer analyzer_; + bool memoize_ = true; +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc new file mode 100644 index 000000000000..5580f6a1ab74 --- /dev/null +++ b/src/relax/ir/dataflow_pattern.cc @@ -0,0 +1,623 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_pattern.cc + * \brief The dataflow pattern language for Relax (inherited from Relay). 
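For illustration, a minimal sketch of the boolean composition sugar defined in dataflow_pattern.cc (operator|, operator&, operator~); all operand patterns are hypothetical and rely on helpers defined later in that file:

DFPattern add_or_mul = IsOp("relax.add")({Wildcard(), Wildcard()}) |
                       IsOp("relax.multiply")({Wildcard(), Wildcard()});
DFPattern not_a_constant = ~IsConst();
DFPattern fp32_add_or_mul = add_or_mul & Wildcard().HasDtype("float32");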
+ */ + +#include +#include + +#include +#include +#include + +#define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { \ + auto* node = static_cast(ref.get()); \ + REPR_LAMBDA(p, node); \ + }) + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ExternFuncPatternNode); +ExternFuncPattern::ExternFuncPattern(String global_symbol) { + ObjectPtr n = make_object(); + n->global_symbol_ = std::move(global_symbol); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { + return ExternFuncPattern(global_symbol); +}); +RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { + p->stream << "ExternFuncPattern(" << node->global_symbol() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(VarPatternNode); +VarPattern::VarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { + return VarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { + p->stream << "VarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { + return DataflowVarPattern(name_hint); +}); +DataflowVarPattern::DataflowVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +RELAX_PATTERN_PRINTER_DEF(DataflowVarPatternNode, [](auto p, auto node) { + p->stream << "DataflowVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(GlobalVarPatternNode); +GlobalVarPattern::GlobalVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { + return GlobalVarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { + p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ExprPatternNode); +ExprPattern::ExprPattern(Expr expr) { + ObjectPtr n = make_object(); + n->expr = std::move(expr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { return ExprPattern(e); }); +RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); + +TVM_REGISTER_NODE_TYPE(ConstantPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { + auto c = ConstantPattern(make_object()); + return c; +}); +RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, + [](auto p, auto node) { p->stream << "ConstantPattern()"; }); + +TVM_REGISTER_NODE_TYPE(CallPatternNode); +CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_wildcard) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->varg_default_wildcard = varg_default_wildcard; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.CallPattern") + .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { + return CallPattern(op, args, varg_default_wildcard); + }); +RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { + p->stream << node->op << "("; + for (size_t i = 0; i < node->args.size(); ++i) { + if (i != 0) p->stream << ", "; + 
p->stream << node->args[i]; + } + if (node->varg_default_wildcard) { + if (node->args.size() != 0) p->stream << ", "; + p->stream << "..."; + } + p->stream << ")"; +}); + +TVM_REGISTER_NODE_TYPE(PrimArrPatternNode); +PrimArrPattern::PrimArrPattern(Array arr) { + ObjectPtr n = make_object(); + n->fields = std::move(arr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { + return PrimArrPattern(std::move(arr)); +}); +RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { + p->stream << "PrimArrPattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(FunctionPatternNode); +FunctionPattern::FunctionPattern(Array params, DFPattern body) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.FunctionPattern") + .set_body_typed([](Array params, DFPattern body) { + return FunctionPattern(params, body); + }); +RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { + p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TuplePatternNode); +TuplePattern::TuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { + return TuplePattern(fields); +}); +RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { + p->stream << "TuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(UnorderedTuplePatternNode); +UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") + .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); +RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { + p->stream << "UnorderedTuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); +TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern").set_body_typed([](DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +}); +RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { + p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AndPatternNode); +AndPattern::AndPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { + return AndPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { + p->stream << "AndPattern(" << node->left << " & " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(OrPatternNode); +OrPattern::OrPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { + return OrPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { 
+ p->stream << "OrPattern(" << node->left << " | " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(NotPatternNode); +NotPattern::NotPattern(DFPattern reject) { + ObjectPtr n = make_object(); + n->reject = std::move(reject); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { + return NotPattern(reject); +}); +RELAX_PATTERN_PRINTER_DEF(NotPatternNode, + [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); + +TVM_REGISTER_NODE_TYPE(WildcardPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; +}); +RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); + +TVM_REGISTER_NODE_TYPE(TypePatternNode); +TypePattern::TypePattern(DFPattern pattern, Type type) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->type = std::move(type); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { + return TypePattern(pattern, type); +}); +RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ShapePatternNode); +ShapePattern::ShapePattern(DFPattern pattern, Array shape) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->shape = std::move(shape); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ShapePattern") + .set_body_typed([](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); +RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { + p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataTypePatternNode); +DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->dtype = std::move(dtype); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.DataTypePattern") + .set_body_typed([](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); +RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { + p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AttrPatternNode); +AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->attrs = std::move(attrs); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AttrPattern").set_body_typed([](DFPattern pattern, DictAttrs attrs) { + return AttrPattern(pattern, attrs); +}); +RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; +}); + +class DFPatternDuplicator : public DFPatternFunctor { + public: + DFPattern VisitDFPattern(const DFPattern& pattern) override { + return DFPatternFunctor::VisitDFPattern(pattern); + } + DFPattern VisitDFPattern_(const OrPatternNode* op) override { + return OrPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const AndPatternNode* op) override { + return AndPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const NotPatternNode* op) override { return NotPattern(op->reject); } + DFPattern VisitDFPattern_(const VarPatternNode* op) override { return VarPattern(op->name); } + DFPattern 
VisitDFPattern_(const ConstantPatternNode* op) override { + return ConstantPattern(make_object()); + } + DFPattern VisitDFPattern_(const WildcardPatternNode* op) override { + return WildcardPattern(make_object()); + } + DFPattern VisitDFPattern_(const ExprPatternNode* op) override { return ExprPattern(op->expr); } + DFPattern VisitDFPattern_(const GlobalVarPatternNode* op) override { + return GlobalVarPattern(op->name); + } + DFPattern VisitDFPattern_(const TuplePatternNode* op) override { + return TuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const UnorderedTuplePatternNode* op) override { + return UnorderedTuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override { + return TupleGetItemPattern(op->tuple, op->index); + } + DFPattern VisitDFPattern_(const CallPatternNode* op) override { + return CallPattern(op->op, op->args); + } + DFPattern VisitDFPattern_(const DataTypePatternNode* op) override { + return DataTypePattern(op->pattern, op->dtype); + } + DFPattern VisitDFPattern_(const FunctionPatternNode* op) override { + return FunctionPattern(op->params, op->body); + } + DFPattern VisitDFPattern_(const ShapePatternNode* op) override { + return ShapePattern(op->pattern, op->shape); + } + DFPattern VisitDFPattern_(const TypePatternNode* op) override { + return TypePattern(op->pattern, op->type); + } + DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { + return DataflowVarPattern(op->name); + } + DFPattern VisitDFPattern_(const ExternFuncPatternNode* op) override { + return ExternFuncPattern(op->global_symbol()); + } + DFPattern VisitDFPattern_(const PrimArrPatternNode* op) override { + return PrimArrPattern(op->fields); + } +}; + +// Syntatic Sugar +CallPattern DFPattern::operator()(const std::vector& args) const { + return CallPattern(*this, Array(args)); +} +OrPattern DFPattern::operator|(const DFPattern& other) const { return OrPattern(*this, other); } + +AndPattern DFPattern::operator&(const DFPattern& other) const { return AndPattern(*this, other); } + +NotPattern DFPattern::operator~() const { return NotPattern(*this); } + +AttrPattern DFPattern::HasAttr(const Map& attrs) const { + return AttrPattern(*this, DictAttrs(attrs)); +} +TypePattern DFPattern::HasType(const Type& type) const { return TypePattern(*this, type); } +DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { + return DataTypePattern(*this, dtype); +} +DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { + return HasDtype(DataType(runtime::String2DLDataType(dtype))); +} +ShapePattern DFPattern::HasShape(const Array& shape) const { + return ShapePattern(*this, shape); +} + +DFPattern::operator PatternSeq() const { return PatternSeq{{*this}}; } + +std::stack& pattern_ctx_stack() { + thread_local std::stack graph_pattern_managers; + return graph_pattern_managers; +} + +Optional PatternContext::Current() { + if (pattern_ctx_stack().empty()) return NullOpt; + return pattern_ctx_stack().top(); +} + +PatternContext::PatternContext(bool incremental) { + auto n = make_object(); + if (incremental) { + ICHECK(!pattern_ctx_stack().empty()) + << "Incremental context needs to be built inside a existing context."; + n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; + n->constraints = pattern_ctx_stack().top()->constraints; + } + + data_ = std::move(n); +} + +void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); } + +void PatternContext::ExitWithScope() { + 
ICHECK(pattern_ctx_stack().top().same_as(*this)); + pattern_ctx_stack().pop(); +} + +static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, PairCons pcon) { + if (auto ctx = PatternContext::Current()) { + ctx.value().add_constraint(lhs, rhs, pcon); + } +} + +TVM_REGISTER_NODE_TYPE(PatternSeqNode); +PatternSeq::PatternSeq(DFPattern init_pattern) { + ObjectPtr n = make_object(); + n->patterns = {init_pattern}; + n->pair_constraints = {}; + data_ = std::move(n); +} +PatternSeq::PatternSeq(tvm::Array patterns, bool only_used_by) { + ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; + const auto cons = PairCons(only_used_by ? PairCons::kOnlyUsedBy : PairCons::kUsedBy); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::vector(n->patterns.size() - 1, cons); + data_ = std::move(n); +} + +PatternSeq PatternSeq::UsedBy(PatternSeq other, int index) const { + return relax::UsedBy(*this, other, index); +} + +PatternSeq PatternSeq::OnlyUsedBy(PatternSeq other, int index) const { + return relax::OnlyUsedBy(*this, other, index); +} + +PatternSeq PatternSeq::dup() const { + PatternSeq ret; + + ObjectPtr n = make_object(); + n->patterns = Array{}; + n->patterns.reserve(get()->patterns.size()); + n->pair_constraints = this->get()->pair_constraints; + + for (size_t i = 0; i < get()->patterns.size(); ++i) { + n->patterns.push_back(get()->patterns[i].dup()); + if (i >= 1) + sync_graph_constraints(n->patterns[i - 1], n->patterns[i], n->pair_constraints[i - 1]); + } + + ret.data_ = std::move(n); + + return ret; +} +TVM_REGISTER_GLOBAL("relax.dpl.PatternSeq") + .set_body_typed([](Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); +RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { + p->stream << "["; + for (size_t i = 0; i < node->patterns.size(); ++i) { + if (i != 0) + p->stream << (PairCons::kOnlyUsedBy == node->pair_constraints[i].type ? 
" >> " : " ^ "); + p->stream << node->patterns[i]; + } + p->stream << "]"; +}); + +TVM_REGISTER_GLOBAL("relax.dpl.used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.UsedBy(rhs, index); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.only_used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.OnlyUsedBy(rhs, index); + }); + +PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), + PairCons{PairCons::kUsedBy, index}); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.UsedBy(rhs); } + +PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), constraint); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.OnlyUsedBy(rhs); } + +VarPattern IsVar(const String& name) { return VarPattern(name); } +ConstantPattern IsConst() { return ConstantPattern(make_object()); } +WildcardPattern Wildcard() { return WildcardPattern(make_object()); } +ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } +ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } +CallPattern IsCallTIR(const String& name, Optional var_args) { + DFPattern arg_pattern; + if (!var_args.defined()) { + arg_pattern = Wildcard(); + } else { + arg_pattern = var_args.value(); + } + + return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); +} + +CallPattern IsCallTIR(const String& name, TuplePattern var_args) { + return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args); +} +CallPattern IsCallDPSPacked(const String& name, Optional var_args) { + DFPattern arg_pattern; + if (!var_args.defined()) { + arg_pattern = Wildcard(); + } else { + arg_pattern = var_args.value(); + } + + return 
IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern); +} + +CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) { + return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args); +} + +DFPattern IsTuple(const Array& fields, bool unordered) { + if (unordered) + return UnorderedTuplePattern(fields); + else + return TuplePattern(fields); +} +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +} + +DFPattern DFPattern::dup() const { + auto pattern = DFPatternDuplicator().VisitDFPattern(*this); + return pattern; +} + +TVM_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { + return pattern.dup(); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { return seq.dup(); }); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { + return PatternContext(incre); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { + return PatternContext::Current(); +}); + +class PatternContext::Internal { + public: + static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope(); } + static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); } +}; + +TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope); + +TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_pattern_functor.cc b/src/relax/ir/dataflow_pattern_functor.cc new file mode 100644 index 000000000000..37a98f28beef --- /dev/null +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. 
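For illustration, a minimal sketch of subclassing the DFPatternVisitor implemented below; the collector class is hypothetical:

// Collects the operator names referenced by ExprPattern leaves of a pattern tree.
class OpNameCollector : public DFPatternVisitor {
 public:
  std::vector<String> names;
  void VisitDFPattern_(const ExprPatternNode* op) override {
    if (const auto* op_node = op->expr.as<OpNode>()) {
      names.push_back(op_node->name);
    }
  }
};
// Usage: OpNameCollector collector; collector.VisitDFPattern(pattern);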
+ */ + +#include + +namespace tvm { +namespace relax { + +// DFPatternVisitor + +void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) { + if (this->visited_.count(pattern.get()) == 0) { + visited_.insert(pattern.get()); + DFPatternFunctor::VisitDFPattern(pattern); + } +} + +void DFPatternVisitor::VisitDFPattern_(const OrPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const AndPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const NotPatternNode* op) { VisitDFPattern(op->reject); } + +void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { + VisitDFPattern(op->op); + if (op->args.defined()) { + for (auto arg : op->args) { + VisitDFPattern(arg); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) { + VisitDFPattern(op->pattern); +} + +void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) { + if (op->params.defined()) { + for (auto param : op->params) { + VisitDFPattern(param); + } + } + VisitDFPattern(op->body); +} + +void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { + VisitDFPattern(op->tuple); +} + +void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const UnorderedTuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } + +// leaf nodes. +void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const DataflowVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const GlobalVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ExternFuncPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc new file mode 100644 index 000000000000..bfb5896c9988 --- /dev/null +++ b/src/relax/ir/emit_te.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.cc + */ +#include "./emit_te.h" + +#include +#include + +namespace tvm { +namespace relax { + +// RXPlaceholderOpNode +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rxplaceholder(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode); + +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { + auto n = make_object(); + n->name = name; + n->value = value; + + // If the value is a constant, it might come as an argument of EmitTE and thus its shape and + // checked-type might not be properly set. In this case we set the shape and dtype of the returned + // TE tensor. + if (const auto* constant = value.as()) { + n->dtype = DataType(constant->data->dtype); + + int ndim = constant->data->ndim; + ShapeTuple shape_tuple = constant->data.Shape(); + Array shape; + shape.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); + } + n->shape = std::move(shape); + return te::PlaceholderOp(n).output(0); + } + ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; + auto* tensor_sinfo = GetStructInfoAs(value); + ICHECK(tensor_sinfo) << "Value must be a tensor"; + auto* shape_expr = tensor_sinfo->shape.as(); + CHECK(shape_expr) + << "ValueError: Expression does not have an known symbolic shape, please consider use " + "match_cast " + << "to constrain the shape before passing into te_tensor"; + n->shape = shape_expr->values.Map( + [&tir_var_map](const PrimExpr& e) { return tir::Substitute(e, tir_var_map); }); + n->dtype = tensor_sinfo->dtype; + return te::PlaceholderOp(n).output(0); +} + +TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h new file mode 100644 index 000000000000..46207479c7ef --- /dev/null +++ b/src/relax/ir/emit_te.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.h + * \brief Tensor expression extension in Relax. + */ +#ifndef TVM_RELAX_IR_EMIT_TE_H_ +#define TVM_RELAX_IR_EMIT_TE_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A placeholder op that represents a relax expression. + */ +class RXPlaceholderOpNode : public te::PlaceholderOpNode { + public: + /*! \brief The relax expression. 
*/ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("attrs", &attrs); + v->Visit("value", &value); + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "RXPlaceholderOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); +}; + +/*! + * \brief Create a TE tensor from relax expression, with TIR variables in the + * tensor shape substituted by the given mapping. + * \param value The relax expression, which is required to have TensorStructInfo. + * \param tir_var_map The mapping to substitute the TIR variables appeared in the + * shape of the input Expr. + * \param name The name of the created tensor. + */ +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_IR_EMIT_TE_H_ diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc new file mode 100644 index 000000000000..5392be7cb69b --- /dev/null +++ b/src/relax/ir/expr.cc @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using tvm::ReprPrinter; +using tvm::runtime::Optional; + +TVM_REGISTER_NODE_TYPE(IdNode); + +Id::Id(String name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + +Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->sinfo_args = std::move(sinfo_args); + n->span = std::move(span); + data_ = std::move(n); +} + +Call WithFields(Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_sinfo_args, + Optional opt_span) { + // Collect new values for fields. + Expr op = opt_op.value_or(call->op); + Array args = opt_args.value_or(call->args); + Attrs attrs = opt_attrs.value_or(call->attrs); + Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + Span span = opt_span.value_or(call->span); + + // Check if anything changed. + bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); + if (unchanged) { + if (args.size() == call->args.size()) { + for (size_t i = 0; i < args.size(); i++) { + unchanged &= args[i].same_as(call->args[i]); + } + } else { + unchanged = false; + } + } + if (unchanged) { + if (sinfo_args.size() == call->sinfo_args.size()) { + for (size_t i = 0; i < sinfo_args.size(); i++) { + unchanged &= sinfo_args[i].same_as(call->sinfo_args[i]); + } + } else { + unchanged = false; + } + } + + if (!unchanged) { + // If call is only references, update it in place. 
Otherwise copy and update. + CallNode* cow_call_node = call.CopyOnWrite(); + cow_call_node->op = op; + cow_call_node->args = args; + cow_call_node->attrs = attrs; + cow_call_node->sinfo_args = sinfo_args; + cow_call_node->span = span; + } + return call; +} + +TVM_REGISTER_NODE_TYPE(CallNode); + +TVM_REGISTER_GLOBAL("relax.Call") + .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, + Span span) { return Call(op, args, attrs, sinfo_args, span); }); + +If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { + ObjectPtr n = make_object(); + n->cond = std::move(cond); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); + n->span = std::move(span); + data_ = std::move(n); +} + +If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, + Optional opt_false_branch, Optional opt_span) { + Expr cond = opt_cond.value_or(if_expr->cond); + Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); + Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); + Span span = opt_span.value_or(if_expr->span); + + bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && + false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); + + if (!unchanged) { + IfNode* cow_if_node = if_expr.CopyOnWrite(); + cow_if_node->cond = cond; + cow_if_node->true_branch = true_branch; + cow_if_node->false_branch = false_branch; + cow_if_node->span = span; + } + return if_expr; +} + +TVM_REGISTER_NODE_TYPE(IfNode); + +TVM_REGISTER_GLOBAL("relax.If") + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); + }); + +Tuple::Tuple(tvm::Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleNode); + +TVM_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { + return Tuple(fields, span); +}); + +Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { + Array fields = opt_fields.value_or(tuple->fields); + Span span = opt_span.value_or(tuple->span); + + bool all_fields_unchanged = true; + if (fields.size() == tuple->fields.size()) { + for (size_t i = 0; i < fields.size(); i++) { + all_fields_unchanged &= fields[i].same_as(tuple->fields[i]); + } + } else { + all_fields_unchanged = false; + } + + all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + if (!all_fields_unchanged) { + TupleNode* cow_tuple_node = tuple.CopyOnWrite(); + cow_tuple_node->fields = fields; + cow_tuple_node->span = span; + } + return tuple; +} + +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + n->span = std::move(span); + data_ = std::move(n); +} + +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_span) { + Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); + Integer index = opt_index.value_or(tuple_get_item->index); + Span span = opt_span.value_or(tuple_get_item->span); + + bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + span.same_as(tuple_get_item->span); + if (!unchanged) { + TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); + cow_tuple_get_item_node->tuple = tuple; + cow_tuple_get_item_node->index = index.IntValue(); + 
cow_tuple_get_item_node->span = span; + } + return tuple_get_item; +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemNode); + +TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); +}); + +TVM_REGISTER_NODE_TYPE(ShapeExprNode); + +ShapeExpr::ShapeExpr(Array values, Span span) { + ObjectPtr n = make_object(); + + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + n->checked_type_ = ShapeType(values.size()); + n->struct_info_ = ShapeStructInfo(values, span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); +}); + +TVM_REGISTER_NODE_TYPE(VarNode); + +Var::Var(Id vid, Optional struct_info_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + if (struct_info_annotation) { + n->checked_type_ = GetStaticType(struct_info_annotation.value()); + } + n->struct_info_ = std::move(struct_info_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Var") + .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { + return Var(name_hint, struct_info_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.VarFromId") + .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { + return Var(vid, struct_info_annotation, span); + }); + +TVM_REGISTER_NODE_TYPE(DataflowVarNode); + +DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + if (struct_info_annotation) { + n->checked_type_ = GetStaticType(struct_info_annotation.value()); + } + n->struct_info_ = std::move(struct_info_annotation); + n->span = std::move(span); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowVar") + .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { + return DataflowVar(name_hint, struct_info_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.DataflowVarFromId") + .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { + return DataflowVar(vid, struct_info_annotation, span); + }); + +Constant::Constant(runtime::NDArray data, Span span) { + ObjectPtr n = make_object(); + n->data = std::move(data); + n->span = std::move(span); + + // set struct info. 
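+  // The struct info is derived eagerly from the NDArray itself: each dimension of
+  // n->data.Shape() becomes an int64 IntImm in a ShapeExpr, and the dtype is taken
+  // from the array, so a Constant always carries a fully static TensorStructInfo.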
+ Array values; + auto shape_tuple = n->data.Shape(); + for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { + values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); + } + TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span); + + n->struct_info_ = tinfo; + n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ConstantNode); + +TVM_REGISTER_GLOBAL("relax.Constant").set_body_typed([](runtime::NDArray data, Span span = Span()) { + return Constant(data, span); +}); + +PrimValue::PrimValue(PrimExpr value, Span span) { + ObjectPtr n = make_object(); + n->checked_type_ = PrimType(value.dtype()); + n->struct_info_ = PrimStructInfo(value.dtype()); + n->value = std::move(value); + n->span = std::move(span); + data_ = std::move(n); +} + +PrimValue PrimValue::Int64(int64_t value, Span span) { + return PrimValue(IntImm(DataType::Int(64), value), span); +} + +TVM_REGISTER_NODE_TYPE(PrimValueNode); + +TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { + return PrimValue(value, span); +}); + +StringImm::StringImm(String value, Span span) { + ObjectPtr n = make_object(); + n->value = std::move(value); + n->span = std::move(span); + // use the base structinfo for now + // we can choose to introduce more fine-grained struct info later if necessary. + n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(StringImmNode); + +TVM_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { + return StringImm(value, span); +}); + +DataTypeImm::DataTypeImm(DataType value, Span span) { + ObjectPtr n = make_object(); + n->value = std::move(value); + n->span = std::move(span); + // use the base structinfo for now + // we can choose to introduce more fine-grained struct info later if necessary. 
+ n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DataTypeImmNode); + +TVM_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { + return DataTypeImm(value, span); +}); + +TVM_REGISTER_NODE_TYPE(MatchCastNode); + +MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { + ObjectPtr n = make_object(); + ICHECK(var.defined()) << "MatchCast requires var to be defined"; + n->var = std::move(var); + n->value = std::move(value); + n->struct_info = std::move(struct_info); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.MatchCast") + .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { + return MatchCast(var, value, struct_info, span); + }); + +TVM_REGISTER_NODE_TYPE(VarBindingNode); + +VarBinding::VarBinding(Var var, Expr value, Span span) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); +}); + +TVM_REGISTER_NODE_TYPE(BindingBlockNode); + +BindingBlock::BindingBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(DataflowBlockNode); + +DataflowBlock::DataflowBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(SeqExprNode); + +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { + ObjectPtr n = make_object(); + n->blocks = std::move(blocks); + n->body = std::move(body); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.SeqExpr") + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); + +TVM_REGISTER_NODE_TYPE(FunctionNode); + +Function::Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs, Span span) { + // Set the function type. + // For function, we take a conservative approach and require the function type + // to be known at construction time. + Array param_sinfo; + + for (const Var& param : params) { + CHECK(param->struct_info_.defined()) + << "relax.Function requires params to contain struct_info_"; + param_sinfo.push_back(GetStructInfo(param)); + } + + Optional body_sinfo; + + if (body->struct_info_.defined()) { + body_sinfo = GetStructInfo(body); + } + + if (ret_struct_info.defined()) { + // allow body to override ret if body is more fine-grained. 
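+    // i.e. if the annotated return struct info is a (possibly strict) base of the
+    // body's struct info, keep the more precise body struct info as the return;
+    // otherwise the annotation is kept as-is.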
+ if (body_sinfo.defined()) { + if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) { + ret_struct_info = body_sinfo; + } + } + } else { + CHECK(body_sinfo.defined()) + << "Function do not have a return signature and body is not normalized"; + ret_struct_info = body_sinfo; + } + + FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value()); + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + n->ret_struct_info = std::move(ret_struct_info.value()); + n->checked_type_ = GetStaticType(func_sinfo); + n->struct_info_ = std::move(func_sinfo); + n->attrs = std::move(attrs); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Function") + .set_body_typed([](Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs, + Span span) { return Function(params, body, ret_struct_info, attrs, span); }); + +Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, DictAttrs attrs, + Span span) { + Array param_sinfo; + for (const Var& param : params) { + ICHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_."; + param_sinfo.push_back(GetStructInfo(param)); + } + FuncStructInfo finfo(param_sinfo, ret_struct_info); + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = Expr(); + n->checked_type_ = GetStaticType(finfo); + n->struct_info_ = std::move(finfo); + n->ret_struct_info = std::move(ret_struct_info); + n->attrs = std::move(attrs); + n->span = std::move(span); + return Function(std::move(n)); +} + +TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") + .set_body_typed([](Array params, StructInfo ret_struct_info, DictAttrs attrs, Span span) { + return Function::CreateEmpty(params, ret_struct_info, attrs, span); + }); + +// Special opaque derivation function for ExternFunc +// Take look at sinfo_args to figure out the return StructInfo. +TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") + .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { + ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } + }); + +// Get the derive function. +FuncStructInfo GetExternFuncStructInfo() { + EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args"); + StructInfoDeriveFunc derive; + derive = fn; + return FuncStructInfo::OpaqueFunc(derive); +} + +TVM_REGISTER_NODE_TYPE(ExternFuncNode); + +ExternFunc::ExternFunc(String global_symbol, Span span) { + ObjectPtr n = make_object(); + n->global_symbol = std::move(global_symbol); + n->span = span; + static auto sinfo = GetExternFuncStructInfo(); + n->struct_info_ = sinfo; + n->checked_type_ = GetStaticType(sinfo); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { + return ExternFunc(global_symbol, span); +}); + +Expr GetShapeOf(const Expr& expr) { + // default case, to be normalized. 
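+  // If the tensor's struct info already records a shape expression, it is returned
+  // directly; otherwise a call to the relax.shape_of op is emitted, annotated with a
+  // ShapeStructInfo of matching ndim, and left to be normalized later.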
+ ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; + auto* tinfo = GetStructInfoAs(expr); + + ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; + if (tinfo->shape.defined()) return tinfo->shape.value(); + + static const Op& op = Op::Get("relax.shape_of"); + // default case, call shape of, eagerly normalize the expr. + relax::Call call_shape_of(op, {expr}, {}, {}); + UpdateStructInfo(call_shape_of, ShapeStructInfo(tinfo->ndim)); + return call_shape_of; +} + +TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { + return GetShapeOf(expr); +}); + +TVM_REGISTER_GLOBAL("relax.FuncWithAttr") + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relax.FuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> Optional { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> Optional { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } + return NullOpt; + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc new file mode 100644 index 000000000000..3f0fc86a2a37 --- /dev/null +++ b/src/relax/ir/expr_functor.cc @@ -0,0 +1,793 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/expr_functor.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +#include +#include +#include +#include + +// functions to be overriden. 
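+// The macros below build a small per-type vtable so that VisitBinding_ on a
+// VarBinding can re-dispatch on the runtime type of the bound value: e.g. for a
+// binding whose value is a CallNode, the two-argument overload
+// VisitBinding_(const VarBindingNode*, const CallNode*) is invoked. ExprVisitor and
+// ExprMutator instantiate the same table via RELAX_VAR_BINDING_DISPATCH_IMPL and
+// differ only in what the per-type handlers do (visit vs. rewrite-and-re-emit).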
+#define RELAX_VISIT_BINDING_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const ObjectRef& n, TSelf* self, const VarBindingNode* binding) { \ + self->VisitBinding_(binding, static_cast(n.get())); \ + }); + +#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type) \ + Type::VisitBindingVTable Type::InitVisitBindingVTable() { \ + VisitBindingVTable vtable; \ + RELAX_VISIT_BINDING_DISPATCH(ConstantNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleNode); \ + RELAX_VISIT_BINDING_DISPATCH(VarNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode); \ + RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(FunctionNode); \ + RELAX_VISIT_BINDING_DISPATCH(CallNode); \ + RELAX_VISIT_BINDING_DISPATCH(SeqExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(IfNode); \ + RELAX_VISIT_BINDING_DISPATCH(OpNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode); \ + RELAX_VISIT_BINDING_DISPATCH(PrimValueNode); \ + RELAX_VISIT_BINDING_DISPATCH(StringImmNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode); \ + return vtable; \ + } \ + void Type::VisitBinding_(const VarBindingNode* binding) { \ + static VisitBindingVTable vtable = InitVisitBindingVTable(); \ + const Expr& value = binding->value; \ + ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \ + ICHECK(vtable.can_dispatch(value)) \ + << "VisitVarBinding do not allow binding value type" << value->GetTypeKey(); \ + vtable(value, this, binding); \ + } + +// functions to be overriden. +#define RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OP) \ + void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ + this->VisitExpr(binding->value); \ + this->VisitVarDef(binding->var); \ + } + +// functions to be overriden. +#define RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OP) \ + void ExprMutator::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ + Expr new_value = this->VisitExpr(binding->value); \ + this->ReEmitBinding(binding, new_value); \ + } + +namespace tvm { +namespace relax { + +// ================== +// ExprVisitor + +void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + default_struct_info_field_visitor_.VisitStructInfo(struct_info); +} + +ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent) + : parent_(parent) {} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) { + parent_->VisitExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) { + parent_->VisitPrimExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. +} + +void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } + +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); + // Constant's StructInfo does not depend on Expr. 
+} + +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo is not value-dep +} + +void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); + for (Expr field : op->fields) { + this->VisitExpr(field); + } + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined Var +void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined DataflowVar +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); + for (Var param : op->params) { + this->VisitVarDef(param); + } + + this->VisitExpr(op->body); + // FuncStructInfo does not depend on Expr. +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->op); + + for (StructInfo sinfo_arg : op->sinfo_args) { + this->VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : op->args) { + this->VisitExpr(arg); + } + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { + for (PrimExpr val : op->values) { + this->VisitPrimExpr(val); + } + this->VisitSpan(op->span); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo does not depend on Expr. 
+} + +void ExprVisitor::VisitExpr_(const SeqExprNode* op) { + this->VisitSpan(op->span); + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + this->VisitExpr(op->body); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const PrimValueNode* op) { + this->VisitPrimExpr(op->value); + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const StringImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const DataTypeImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitSpan(const Span& span) {} + +void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} + +// implementations of binding visitor dispatch +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); +} + +void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } +} + +void ExprVisitor::VisitVarDef(const Var& var) { + if (const auto* node = var.as()) { + VisitVarDef_(node); + } else if (const auto* node = var.as()) { + VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } +} + +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + 
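+/*!
+ * \brief Usage sketch for PostOrderVisit (illustrative only; the variable `expr`
+ * and the counting logic are hypothetical, not part of any API defined here).
+ * The callback runs on every sub-expression after its children have been visited.
+ *
+ * \code
+ *   size_t num_calls = 0;
+ *   PostOrderVisit(expr, [&num_calls](const Expr& e) {
+ *     if (e.as<CallNode>()) {
+ *       ++num_calls;
+ *     }
+ *   });
+ * \endcode
+ */
+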
+TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); + +// ================== +// ExprMutatorBase + +StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + return default_struct_info_field_mutator_.VisitStructInfo(struct_info); +} + +ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( + ExprMutatorBase* parent) + : parent_(parent) {} + +Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { + return parent_->VisitExpr(expr); +} + +PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( + const PrimExpr& expr) { + return parent_->VisitPrimExpr(expr); +} + +StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( + const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } + +Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { + // Constant' struct info won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { + // FuncStructInfo won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { + bool unchanged = true; + tvm::Array fields; + for (Expr field : op->fields) { + Expr new_field = this->VisitExpr(field); + fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + if (unchanged) { + // If tuple's struct info change it means that + // one of its fields' struct info will change + // so un-changed already implies that struct info won't change + return GetRef(op); + } else { + // when there is a change return a new tuple node + return Tuple(fields, op->span); + } +} + +// Visit the use-site of a defined Var +Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { + // struct info of function is not value dependent + // so no need to check struct_info field + Expr body = this->VisitExpr(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { + Expr new_op = this->VisitExpr(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + Array sinfo_args; + for (StructInfo sinfo_arg : call_node->sinfo_args) { + StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); + sinfo_args.push_back(new_sinfo_arg); + unchanged &= new_sinfo_arg.same_as(sinfo_arg); + } + + tvm::Array call_args; + for (Expr arg : call_node->args) { + Expr new_arg = this->VisitExpr(arg); + call_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { + return 
GetRef(call_node); + } else { + return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitExpr(op->true_branch); + Expr false_b = this->VisitExpr(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { + auto t = this->VisitExpr(op->tuple); + if (op->tuple.same_as(t)) { + // struct info can be deterministically derived by tuple and index + // if t does not change, then struct info won't change. + return GetRef(op); + } else { + return TupleGetItem(t, op->index, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { + auto value = this->VisitPrimExpr(op->value); + if (op->value.same_as(value)) { + // struct info can be deterministically derived by value + // if value does not change, then struct info won't change. + return GetRef(op); + } + return PrimValue(value, op->span); +} + +Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { + auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); + + if (values.same_as(op->values)) { + // If values does not change, struct info won't change. + return GetRef(op); + } else { + return ShapeExpr(values, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { + // StructInfo of function remains value independent. 
+ return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + Expr body = this->VisitExpr(op->body); + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } + return SeqExpr(blocks, body); +} + +BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { + Array bindings; + if (const auto* node = block.as()) { + for (auto binding : node->bindings) { + if (auto var_binding = binding.as()) { + Expr new_value = this->VisitExpr(var_binding->value); + bindings.push_back(VarBinding(var_binding->var, new_value)); + } else if (auto match_cast = binding.as()) { + Expr new_value = this->VisitExpr(match_cast->value); + bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + + if (block.as()) { + return DataflowBlock(bindings); + } else { + return BindingBlock(bindings); + } +} + +PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } + +// ================== +// ExprMutator + +Expr ExprMutator::VisitExpr(const Expr& expr) { + return builder_->Normalize(ExprFunctor::VisitExpr(expr)); +} + +// Visit the use-site of a defined Var +Expr ExprMutator::VisitExpr_(const VarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. 
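+  // (vars are replaced only through var_remap_, which is keyed by the variable's
+  // vid; a var that was never re-bound by an enclosing visit is returned untouched)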
+ return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : op->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + if (!param.same_as(new_param)) { + var_remap_[param->vid] = new_param; + all_params_unchanged = false; + } + } + + Expr body = this->VisitWithNewScope(op->body, params); + + // FuncStructInfo does not depend on Expr + if (all_params_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprMutator); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { + Var new_var = this->VisitVarDef(binding->var); + + // fast path: re-emit binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + return; + } + + Var temp = WithStructInfo(new_var, GetStructInfo(new_value)); + if (!temp.same_as(new_var)) { + new_var = temp; + } + this->var_remap_[binding->var->vid] = new_var; + + builder_->EmitNormalized(VarBinding(new_var, new_value)); +} + +void ExprMutator::VisitBinding_(const MatchCastNode* binding) { + Var new_var = this->VisitVarDef(binding->var); + Expr new_value = this->VisitExpr(binding->value); + + // re-emit old binding if nothing changes + if (new_var.same_as(binding->var) && 
new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + new_value = builder_->NormalizeArgument(new_value); + builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); + } +} + +BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return DataflowVar(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +Var ExprMutator::VisitVarDef_(const VarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return Var(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +void ExprMutator::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; +} + +Var ExprMutator::VisitVarDef(const Var& var) { + Var ret; + if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + return ret; +} + +Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + builder_->EndScope(); + return ret; +} + +Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } + +Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { + ICHECK(struct_info.defined()); + + // TODO(relax-team) add StructInfoEqual check + if (var->struct_info_.defined()) { + // use same-as as a quick path + if (var->struct_info_.same_as(struct_info) || + StructuralEqual()(var->struct_info_, struct_info)) { + return var; + } else { + Var new_var = var.as() ? 
DataflowVar(var->vid, struct_info, var->span) + : Var(var->vid, struct_info, var->span); + return new_var; + } + } else { + UpdateStructInfo(var, struct_info); + return var; + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc new file mode 100644 index 000000000000..7e86235aa61e --- /dev/null +++ b/src/relax/ir/py_expr_functor.cc @@ -0,0 +1,649 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/py_expr_functor.cc + * \brief The backbone of PyExprVisitor/PyExprMutator. + */ +#include + +namespace tvm { +namespace relax { + +/*! + * \brief The abstract interface of ExprVisitor. + */ +class PyExprVisitorNode : public Object, public ExprVisitor { + private: + using TSelf = PyExprVisitorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. 
*/ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + void VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + f_visit_expr(expr); + } else { + // Need to init the overwrite VTable + static FType vtable = InitVTable(); + vtable(expr, this); + } + } + + void VisitBinding(const Binding& binding) + PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); + + void VisitBinding_(const VarBindingNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + ExprVisitor::VisitBinding_(binding)); + void VisitBinding_(const MatchCastNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_cast_, + ExprVisitor::VisitBinding_(binding)); + + void VisitBindingBlock(const BindingBlock& block) + PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); + + void VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprVisitor::VisitBindingBlock_(block)); + void VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprVisitor::VisitBindingBlock_(block)); + + void VisitVarDef(const Var& var) + PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); + void VisitVarDef_(const VarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + void VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprVisitor::VisitVarDef_(var)); + + void VisitSpan(const Span& span) + PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + + 
private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_VISITOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_VISITOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_VISITOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprVisitorNode); + +/*! + * \brief Managed reference to PyExprVisitorNode. + * \sa PyExprVisitorNode + */ +class PyExprVisitor : public ObjectRef { + public: + /*! + * \brief Create a PyExprVisitor with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. 
+ * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyVisitor created. + */ + TVM_DLL static PyExprVisitor MakePyExprVisitor( + PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, + PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->f_visit_expr = f_visit_expr; + n->f_visit_binding = f_visit_binding; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_span = f_visit_span; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + return PyExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); +}; + +/*! + * \brief The abstract interface of ExprMutator. + */ +class PyExprMutatorNode : public Object, public ExprMutator { + private: + using TSelf = PyExprMutatorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! 
\brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. 
*/ + PackedFunc f_visit_span{nullptr}; + + Expr VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + return builder_->Normalize(f_visit_expr(expr)); + } else { + static FType vtable = InitVTable(); + return builder_->Normalize(vtable(expr, this)); + } + } + + void VisitBinding(const Binding& binding) { + if (f_visit_binding != nullptr) + f_visit_binding(binding); + else + ExprMutator::VisitBinding(binding); + } + + void VisitBinding_(const VarBindingNode* binding) { + if (f_visit_var_binding_ != nullptr) + f_visit_var_binding_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + void VisitBinding_(const MatchCastNode* binding) { + if (f_visit_match_cast_ != nullptr) + f_visit_match_cast_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) + PY_EXPR_MUTATOR_DEFAULT(block, f_visit_binding_block, ExprMutator::VisitBindingBlock(block), + BindingBlock); + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + + Var VisitVarDef(const Var& var) + PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); + Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprMutator::VisitVarDef_(var), Var); + + /*! + * \brief Dispatcher for post-order rewrite. + * \param expr The Expr to be rewritten. + * \return The Expr after post-order rewritten. + */ + Expr VisitExprPostOrder(const Expr& expr) { + static FType post_order_vtable = InitPostOrderVTable(); + return post_order_vtable(expr, this); + } + + using ExprMutator::builder_; + using ExprMutator::LookupBinding; + using ExprMutator::var_remap_; + using ExprMutator::VisitWithNewScope; + using ExprMutator::WithStructInfo; + + void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + + private: + // initialize the vtable. 
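+  // Each dispatch entry falls back to the corresponding ExprMutator:: method when
+  // the Python-side packed function is left unset (nullptr), so a mutator that
+  // overrides only a subset of node types still recurses through the rest.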
+ static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_MUTATOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_MUTATOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } + + // initialize the vtable for post order visit. + static FType InitPostOrderVTable() { + FType post_order_vtable; + // Set dispatch + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ConstantNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(CallNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(SeqExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode); + return post_order_vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprMutatorNode); + +/*! + * \brief Managed reference to PyExprMutatorNode. + * \sa PyExprMutatorNode + */ +class PyExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyExprMutator with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. 
+ * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyExprMutator created. + */ + TVM_DLL static PyExprMutator MakePyExprMutator( + BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, + PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->builder_ = builder_; + n->f_visit_expr = f_visit_expr; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_binding = f_visit_binding; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_binding_block_ = f_visit_binding_block_; + 
n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + n->f_visit_span = f_visit_span; + return PyExprMutator(n); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); +}; + +TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->ExprVisitor::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->ExprVisitor::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { + visitor->ExprVisitor::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") + .set_body_typed([](PyExprVisitor visitor, const Span& span) { + visitor->ExprVisitor::VisitSpan(span); + }); + +TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + mutator->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->ExprMutator::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + return mutator->ExprMutator::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->ExprMutator::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->ExprMutator::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") + 
.set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExprPostOrder(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitWithNewScope(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->LookupBinding(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") + .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { + return mutator->WithStructInfo(var, sinfo); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id, Var var) { + return mutator->var_remap_[id] = var; + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc new file mode 100644 index 000000000000..4004ad28d560 --- /dev/null +++ b/src/relax/ir/struct_info.cc @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/struct_info.cc + * \brief Relax struct info. 
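+ *
+ * A rough usage sketch (illustrative only; it assumes the ShapeExpr and
+ * TensorStructInfo constructors declared in the corresponding headers and uses
+ * a hypothetical symbolic dimension "n"):
+ * \code
+ *   tir::Var n("n", DataType::Int(64));
+ *   ShapeExpr shape({n, IntImm(DataType::Int(64), 128)});
+ *   TensorStructInfo sinfo(shape, DataType::Float(32), Span());
+ * \endcode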
+ */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +ObjectStructInfo::ObjectStructInfo(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { + return ObjectStructInfo(span); +}); + +// Prim +PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { + return PrimStructInfo(dtype, span); +}); + +// Shape +ShapeStructInfo::ShapeStructInfo(Array values, Span span) { + ObjectPtr n = make_object(); + n->ndim = static_cast(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") + .set_body_typed([](Optional> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); + +// Tensor +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { + ObjectPtr n = make_object(); + // assign ndim before move + Optional sinfo = MatchStructInfo(shape); + ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + ICHECK(shape->IsInstance() || shape->IsInstance()) + << "We require shape to be normalized when constructing TensorStructInfo"; + n->ndim = sinfo.get()->ndim; + // assign rest of the fields. 
+ n->shape = std::move(shape); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TensorStructInfo") + .set_body_typed([](Optional shape, DataType dtype, int ndim, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype, span); + } else { + return TensorStructInfo(dtype, ndim, span); + } + }); + +// Tuple +TupleStructInfo::TupleStructInfo(Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TupleStructInfo") + .set_body_typed([](Array fields, Span span) { + return TupleStructInfo(fields, span); + }); + +// Func +FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->ret = std::move(ret); + n->span = span; + data_ = std::move(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { + ObjectPtr n = make_object(); + n->derive_func = std::move(derive_func); + n->ret = ObjectStructInfo(); + n->span = span; + return FuncStructInfo(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->ret = std::move(ret); + n->span = span; + return FuncStructInfo(n); +} + +TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfo") + .set_body_typed([](Array params, StructInfo ret, Span span) { + return FuncStructInfo(params, ret, span); + }); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") + .set_body_typed([](Optional ret, Optional derive_func, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + } + }); + +// Helper functions +void UpdateStructInfo(Expr expr, StructInfo struct_info) { + ICHECK(!expr->struct_info_.defined()) + << "the struct_info_ of the Expr to be updated must be nullptr for idempotency"; + expr->struct_info_ = struct_info; + // also set checked type + expr->checked_type_ = GetStaticType(struct_info); +} + +TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { + UpdateStructInfo(expr, struct_info); +}); + +TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { + return GetStructInfo(expr); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc new file mode 100644 index 000000000000..199491e3c63f --- /dev/null +++ b/src/relax/ir/struct_info_functor.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file struct_info_functor.cc
+ * \brief Implementations of struct info functors.
+ */
+#include <tvm/relax/struct_info_functor.h>
+
+namespace tvm {
+namespace relax {
+
+void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {}
+
+void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {}
+
+void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) {
+  if (op->values.defined()) {
+    for (PrimExpr value : op->values.value()) {
+      this->VisitStructInfoExprField(value);
+    }
+  }
+}
+
+void StructInfoVisitor::VisitStructInfo_(const TensorStructInfoNode* op) {
+  if (op->shape.defined()) {
+    this->VisitStructInfoExprField(op->shape.value());
+  }
+}
+
+void StructInfoVisitor::VisitStructInfo_(const TupleStructInfoNode* op) {
+  for (StructInfo field : op->fields) {
+    this->VisitStructInfo(field);
+  }
+}
+
+void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) {
+  if (op->params.defined()) {
+    for (StructInfo param : op->params.value()) {
+      this->VisitStructInfo(param);
+    }
+  }
+  this->VisitStructInfo(op->ret);
+}
+
+StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) {
+  return GetRef<StructInfo>(op);
+}
+
+StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) {
+  return GetRef<StructInfo>(op);
+}
+
+StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
+  Optional<Array<PrimExpr>> values;
+
+  if (op->values.defined()) {
+    // if no changes are made the original array will be returned.
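+    // Array::Map only allocates a new container when an element actually changes,
+    // so when every value maps to itself the original array object is reused and
+    // the same_as check below detects that nothing was rewritten.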
+ values = op->values.value().Map( + [this](const PrimExpr& expr) { return this->VisitStructInfoExprField(expr); }); + } + + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { + Optional shape; + + if (op->shape.defined()) { + shape = this->VisitStructInfoExprField(op->shape.value()); + } + + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + + if (fields.same_as(op->fields)) { + return GetRef(op); + } else { + return TupleStructInfo(fields, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { + Optional> params; + + if (op->params.defined()) { + params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + } + + StructInfo ret = this->VisitStructInfo(op->ret); + + if (params.same_as(op->params) && ret.same_as(op->ret)) { + return GetRef(op); + } else { + ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; + return FuncStructInfo(params.value(), ret, op->span); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc new file mode 100644 index 000000000000..cbe4170bb979 --- /dev/null +++ b/src/relax/ir/tir_pattern.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +MatchResult::MatchResult(TIRPattern pattern, Array symbol_values, + Array matched_buffers) { + auto n = make_object(); + n->pattern = std::move(pattern); + n->symbol_values = std::move(symbol_values); + n->matched_buffers = std::move(matched_buffers); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(MatchResultNode); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc new file mode 100644 index 000000000000..9f418bff5c6d --- /dev/null +++ b/src/relax/ir/transform.cc @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/ir/transform.cc + * \brief Relax specific transformation passes. + */ +#include +#include +#include +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { +namespace transform { + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm); + +// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream +class FunctionPass; + +/*! + * \brief Function-level passes are used to implement various global + * optimizations for a given Relax IRModule. It fetches one function at a time + * from the function list in the IRModule for optimization. + * + * Note that the scope of passes at this level is a Relax function. Therefore, + * we cannot add or delete a function through these passes as they are not aware + * of the global information. + */ +class FunctionPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax function as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each function in the IRModule. + */ + runtime::TypedPackedFunc pass_func; + + FunctionPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The IRModule that an optimization pass is applied on. + * \param pass_ctx The context that an optimization pass executes on. + * + * \return Return the updated IRModule. + */ + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.FunctionPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + + private: + /* + * \brief Check if a function should be skipped for optimization. + * + * \param func The target function to be checked. + * + * \return Return true if the function will be skipped, otherwise false. + */ + bool SkipFunction(const Function& func) const; +}; + +class FunctionPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; + +FunctionPass::FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule optimizations at the Function level. 
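+// The pass swaps its own DiagnosticContext in for the duration of the run and
+// restores the caller's context afterwards, so diagnostics raised by pass_func
+// are rendered against the module currently being transformed.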
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << mod; + + IRModule updated_mod = mod->ShallowCopy(); + + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << updated_mod; + + return updated_mod; +} + +bool FunctionPassNode::SkipFunction(const Function& func) const { + // TODO(@yuchen): will need to revisit in the future + return (func->GetAttr(relay::attr::kCompiler).defined()) || + func->GetAttr(relay::attr::kSkipOptimization, 0) != 0; +} + +Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); + return FunctionPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(FunctionPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); + +class DataflowBlockPass; + +/*! + * \brief DataflowBlock-level passes are used to implement various dataflow block + * optimizations for a given Relax IRModule. It fetches one dataflow block at a time + * from the functions in an IRModule, and yields a rewritten DataflowBlock. + * + * Note that the scope of passes at this level is a Relax DataflowBlock. Therefore, + * we cannot modify the global scope Vars and symbolic shape Vars defined inside the dataflow block. + */ +class DataflowBlockPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax DataflowBlock as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each DataflowBlock in the IRModule. 
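+   *
+   * For illustration, a hypothetical no-op `pass_func`, matching the call site
+   * in DataflowBlockMutator below, could be written as:
+   * \code
+   * runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
+   *     [](DataflowBlock block, IRModule mod, PassContext ctx) { return block; };
+   * \endcode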
+ */ + runtime::TypedPackedFunc pass_func; + + DataflowBlockPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.DataflowBlockPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode); +}; + +/*! \brief Helper to apply the passed function to dataflow blocks.*/ +class DataflowBlockMutator : public ExprMutator { + public: + DataflowBlockMutator( + runtime::TypedPackedFunc pass_func, + IRModule mod, PassContext pass_ctx) + : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {} + + /*! + * \brief Rewrite the DataflowBlockNode with pass_func_ + * + * This function will check that there are no rewrites of the global scope Vars + * and symbolic shape Vars defined inside the dataflow block. + */ + BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final { + // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock + Map global_scope_vars; + Map symbolic_vars; + for (const Binding& binding : n->bindings) { + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + symbolic_vars.Set(var->name_hint, GetRef(var)); + } + } + if (!var.as()) { + global_scope_vars.Set(var->name_hint(), var); + } + } + + // apply pass_func_ to the DataflowBlock + DataflowBlock block = GetRef(n); + DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_); + + // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars + for (const Binding& binding : updated_block->bindings) { + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + if (symbolic_vars.count(var->name_hint) > 0) { + tir::Var old_var = symbolic_vars[var->name_hint]; + ICHECK(var == old_var.get()) + << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; + symbolic_vars.erase(var->name_hint); + } + } + } + if (!var.as() && global_scope_vars.count(var->name_hint()) > 0) { + ICHECK(var.same_as(global_scope_vars[var->name_hint()])) + << "Error: DataflowBlock Pass should not rewrite any GlobalScope Var."; + global_scope_vars.erase(var->name_hint()); + } + } + ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) + << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var."; + + return std::move(updated_block); + } + + private: + class SymbolicVarCollector : public StructInfoVisitor { + public: + static std::unordered_set Collect(const StructInfo& info) { + SymbolicVarCollector collector; + collector.VisitStructInfo(info); + return std::move(collector.symbolic_vars_); + } + + private: + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (const tir::VarNode* sym_var = expr.as()) { + symbolic_vars_.insert(sym_var); + } + } + + private: + std::unordered_set symbolic_vars_; + }; + + runtime::TypedPackedFunc pass_func_; + IRModule mod_; + PassContext pass_ctx_; +}; + +class DataflowBlockPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. 
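+   *
+   * Typically this is constructed via CreateDataflowBlockPass below, e.g.
+   * (illustrative; `identity_func` is a hypothetical packed function with the
+   * signature documented above):
+   * \code
+   *   Pass p = CreateDataflowBlockPass(identity_func, 0, "IdentityDataflowBlock", {}, false);
+   * \endcode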
+ */ + TVM_DLL DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode); +}; + +DataflowBlockPass::DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule transformations at the DataflowBlock level. +IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing DataflowBlock pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << mod; + + IRModule updated_mod = mod->ShallowCopy(); + + DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod, pass_ctx); + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + Function updated_func = Downcast(dataflow_block_mutator.VisitExpr(func)); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << updated_mod; + + return updated_mod; +} + +Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); + return DataflowBlockPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return DataflowBlockPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run DataflowBlock pass: " << info->name << " at the optimization level " + << info->opt_level; + }); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc new file mode 100644 index 000000000000..49ef1d7163f1 --- /dev/null +++ b/src/relax/ir/type.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/type.cc + * \brief Relax type system. + */ +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ShapeTypeNode); + +ShapeType::ShapeType(int ndim, Span span) { + ObjectPtr n = make_object(); + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { + return ShapeType(ndim, span); +}); + +ObjectType::ObjectType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectTypeNode); + +TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); }); + +DynTensorType::DynTensorType(int ndim, DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = std::move(ndim); + n->dtype = std::move(dtype); + n->span = span; + data_ = std::move(n); +} + +DynTensorType DynTensorType::CreateUnknownNDim(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = -1; + n->dtype = std::move(dtype); + n->span = std::move(span); + return DynTensorType(std::move(n)); +} + +TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); + +TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int ndim, DataType dtype, Span span) { + return DynTensorType(ndim, dtype, span); +}); + +PackedFuncType::PackedFuncType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); + +TVM_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { + return PackedFuncType(span); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc new file mode 100644 index 000000000000..6d49bea6b656 --- /dev/null +++ b/src/relax/op/image/resize.cc @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.cc + * \brief Image resize operators. 
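+ *
+ * Shape inference for relax.image.resize2d keeps the input layout: when the
+ * data shape is known and the target size is a constant ShapeExpr, the two
+ * spatial dimensions (H and W in the NCHW view) are replaced by the requested
+ * size; otherwise only the output rank and dtype are inferred.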
+ */ + +#include "resize.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.resize2d */ +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->roi = std::move(roi); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode); + attrs->rounding_method = std::move(rounding_method); + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->extrapolation_value = extrapolation_value; + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.image.resize2d"); + return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); + +StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1 && call->args.size() != 2) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects either one or two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* size_value = call->args[1].as(); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Resize2D expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (size_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects the given output image size to be a Shape, while the given one is " + << call->args[1]->GetTypeKey()); + } + if (size_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to " + "be a 2-dim shape, while the given one has ndim " + << size_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + + DataType out_dtype = attrs->out_dtype.is_void() ? 
data_sinfo->dtype : attrs->out_dtype; + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, GetRef(data_sinfo), data_layout); + if (!data_shape.defined() || size_value == nullptr) { + return TensorStructInfo(out_dtype, data_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array out_NCHW_shape(data_NCHW_shape); + out_NCHW_shape.Set(2, size_value->values[0]); + out_NCHW_shape.Set(3, size_value->values[1]); + + Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +InferLayoutOutput InferLayoutResize2d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name(); + return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.image.resize2d") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Shape", "The output image shape.") + .set_attr("FInferStructInfo", InferStructInfoResize2D) + .set_attr("FRelaxInferLayout", InferLayoutResize2d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h new file mode 100644 index 000000000000..085a1cbc5d5f --- /dev/null +++ b/src/relax/op/image/resize.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.h + * \brief The functions to make Relax image resize operator calls. + */ + +#ifndef TVM_RELAX_OP_IMAGE_RESIZE_H_ +#define TVM_RELAX_OP_IMAGE_RESIZE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Image resize2d operator. */ +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_IMAGE_RESIZE_H_ diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc new file mode 100644 index 000000000000..c27e8b68d0bc --- /dev/null +++ b/src/relax/op/nn/attention.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "attention.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.attention */ +TVM_REGISTER_NODE_TYPE(AttentionAttrs); + +Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale) { + ObjectPtr attrs = make_object(); + attrs->scale = scale; + if (bias.defined()) { + return Call(Op::Get("relax.nn.attention_bias"), + {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, + Attrs(attrs), {}); + } + return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)}, + Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); + +StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo q_sinfo = input_sinfo[0]; + TensorStructInfo k_sinfo = input_sinfo[1]; + TensorStructInfo v_sinfo = input_sinfo[2]; + auto diag_dim = [&](TensorStructInfo sinfo, String name) { + if (sinfo->ndim != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The " << name << " should have 4 dimension, namely " + << "[batch size, sequence length, number of heads, dimension of heads]."); + } + }; + diag_dim(q_sinfo, "query"); + diag_dim(k_sinfo, "key"); + diag_dim(v_sinfo, "value"); + const ShapeExprNode* q_shape = q_sinfo->shape.as(); + const ShapeExprNode* k_shape = k_sinfo->shape.as(); + const ShapeExprNode* v_shape = v_sinfo->shape.as(); + PrimExpr num_batches = q_shape->values[0]; + PrimExpr num_queries = q_shape->values[1]; + PrimExpr num_heads = q_shape->values[2]; + PrimExpr head_dim = q_shape->values[3]; + PrimExpr num_keys = k_shape->values[1]; + PrimExpr head_dim_value = v_shape->values[3]; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + if (analyzer->CanProve(v1 != v2)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The " << m1 << " " << dim << " and the " << m2 << " " << dim + << " should be the same. 
However, the " << dim << " of " << m1 << " is " + << v1 << " while the " << dim << " of " << m2 << " is " << v2); + } + }; + diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); + diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); + diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); + diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); + diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); + + if (input_sinfo.size() == 4) { + TensorStructInfo bias_sinfo = input_sinfo[3]; + const ShapeExprNode* bias_shape = bias_sinfo->shape.as(); + if (bias_sinfo->ndim == 4) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads"); + diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length"); + diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length"); + } else if (bias_sinfo->ndim == 3) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length"); + diag_equal(num_keys, bias_shape->values[2], "key", "bias", "sequence length"); + } else if (bias_sinfo->ndim == 2) { + diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size"); + diag_equal(num_keys, bias_shape->values[1], "key", "bias", "sequence length"); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "The bias should have 2, 3 or 4 dimensions." + << "However, the bias input has " << bias_sinfo->ndim << " dimensions."); + } + } + + Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; + return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.attention") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .set_attr("FInferStructInfo", InferStructInfoAttention); + +TVM_REGISTER_OP("relax.nn.attention_bias") + .set_attrs_type() + .set_num_inputs(4) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .add_argument("bias", "Tensor", "The input bias tensor.") + .set_attr("FInferStructInfo", InferStructInfoAttention); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h new file mode 100644 index 000000000000..7eda30b40813 --- /dev/null +++ b/src/relax/op/nn/attention.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file attention.h + * \brief The functions to make Relax attention operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_ATTENTION_H_ +#define TVM_RELAX_OP_NN_ATTENTION_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief fused multi head attention */ +Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_ATTENTION_H_ diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc new file mode 100644 index 000000000000..ae84409c2a14 --- /dev/null +++ b/src/relax/op/nn/convolution.cc @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/op/nn/convolution.cc + * \brief Convolution operators + */ + +#include "convolution.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.nn.conv1d */ +TVM_REGISTER_NODE_TYPE(Conv1DAttrs); + +Expr conv1d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype) { + padding = GetCompletePadding1D(std::move(padding)); + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(strides.size(), 1) + << "The input strides length is expected to be 1. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 1) + << "The input dilation length is expected to be 1. 
However, the given dilation is " + << dilation; + return MakeConv(std::move(data), std::move(weight), std::move(strides), + std::move(padding), std::move(dilation), groups, data_layout, + std::move(kernel_layout), out_layout.value_or(data_layout), + out_dtype, /*op_name=*/"relax.nn.conv1d"); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); + +StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2OIW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"OIW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + Array weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCW_shape[1]; + PrimExpr input_channel_kernel = weight_OIW_shape[1]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "The channel size of the data should equal to the product of input channel size of the " + "weight and the number of groups. However, the data channel size is " + << input_channel_data << " while the weight input channel size and number of groups are " + << input_channel_kernel << " and " << attrs->groups); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(weight_OIW_shape[0], attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv1d expects the number of output channels to be divisible by the " + "number of groups. 
However, the number of output channels is " + << weight_OIW_shape[0] << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(weight_OIW_shape[0], attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_w = data_NCW_shape[2]; + PrimExpr kernel_w = weight_OIW_shape[2]; + PrimExpr padding_w = attrs->padding[0] + attrs->padding[1]; + + std::vector out_NCW_shape; + out_NCW_shape.resize(3); + out_NCW_shape[0] = data_NCW_shape[0]; + out_NCW_shape[1] = weight_OIW_shape[0]; + + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; + out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); + + Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +InferLayoutOutput InferLayoutConv1d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const auto& it = desired_layouts.find("relax.nn.conv1d"); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision data_layout, weight_layout, output_layout; + ObjectPtr new_attrs = make_object(*attrs); + + if (it != desired_layouts.end()) { + // We have a desired layout for conv1d. + Layout desired_data_layout = (*it).second[0]; + Layout desired_weight_layout = (*it).second[1]; + Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; + ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + << "Axis swap only"; + ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + << "Axis swap only"; + data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout); + weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout); + output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout); + new_attrs->data_layout = (*it).second[0]; + new_attrs->kernel_layout = (*it).second[1]; + new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + } else { + // We don't have a desired layout for conv1d. + // We can just propagate the layout from the input. 
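+ // The output follows the propagated data layout, while the weight keeps its own inferred layout; the attrs below are rewritten so they describe the tensors in their new layouts.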
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; + new_attrs->data_layout = + TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name(); + new_attrs->kernel_layout = + TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name(); + new_attrs->out_layout = + TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name(); + } + return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); +} + +Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { + const auto* conv1d_attrs = call->attrs.as(); + return Downcast(conv1d(call->args[0], call->args[1], conv1d_attrs->strides, + conv1d_attrs->padding, conv1d_attrs->dilation, conv1d_attrs->groups, + conv1d_attrs->data_layout, conv1d_attrs->kernel_layout, + conv1d_attrs->out_layout, out_dtype)); +} + +TVM_REGISTER_OP("relax.nn.conv1d") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoConv1d) + .set_attr("FRelaxInferLayout", InferLayoutConv1d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) + .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1d); + +/* relax.nn.conv2d */ +TVM_REGISTER_NODE_TYPE(Conv2DAttrs); + +Expr conv2d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype) { + padding = GetCompletePadding2D(std::move(padding)); + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. However, the given dilation is " + << dilation; + return MakeConv(std::move(data), std::move(weight), std::move(strides), + std::move(padding), std::move(dilation), groups, data_layout, + std::move(kernel_layout), out_layout.value_or(data_layout), + out_dtype, /*op_name=*/"relax.nn.conv2d"); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); + +StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2OIHW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"OIHW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? 
InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCHW_shape[1]; + PrimExpr input_channel_kernel = weight_OIHW_shape[1]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "The channel size of the data should equal to the product of input channel size of the " + "weight and the number of groups. However, the data channel size is " + << input_channel_data << " while the weight input channel size and number of groups are " + << input_channel_kernel << " and " << attrs->groups); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv2d expects the number of output channels to be divisible by the " + "number of groups. However, the number of output channels is " + << weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = weight_OIHW_shape[2]; + PrimExpr kernel_w = weight_OIHW_shape[3]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + std::vector out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = weight_OIHW_shape[0]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +InferLayoutOutput InferLayoutConv2d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const auto& it = desired_layouts.find("relax.nn.conv2d"); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision data_layout, weight_layout, output_layout; + ObjectPtr new_attrs = make_object(*attrs); + + if (it != desired_layouts.end()) { + // We have a desired layout for conv2d. + Layout desired_data_layout = (*it).second[0]; + Layout desired_weight_layout = (*it).second[1]; + Layout desired_output_layout = (*it).second.size() == 3 ? 
(*it).second[2] : (*it).second[0]; + ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; + ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + << "Axis swap only"; + ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + << "Axis swap only"; + data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout); + weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout); + output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout); + new_attrs->data_layout = (*it).second[0]; + new_attrs->kernel_layout = (*it).second[1]; + new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + } else { + // We don't have a desired layout for conv2d. + // We can just propagate the layout from the input. + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; + new_attrs->data_layout = + TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); + new_attrs->kernel_layout = + TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name(); + new_attrs->out_layout = + TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name(); + } + return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); +} + +Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) { + const auto* conv2d_attrs = call->attrs.as(); + return Downcast(conv2d(call->args[0], call->args[1], conv2d_attrs->strides, + conv2d_attrs->padding, conv2d_attrs->dilation, conv2d_attrs->groups, + conv2d_attrs->data_layout, conv2d_attrs->kernel_layout, + conv2d_attrs->out_layout, out_dtype)); +} + +TVM_REGISTER_OP("relax.nn.conv2d") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoConv2d) + .set_attr("FRelaxInferLayout", InferLayoutConv2d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) + .set_attr("FInferMixedPrecision", InferMixedPrecisionConv2d); + +/* relax.nn.conv2d_transpose */ +TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); + +Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array padding, + Array output_padding, Array dilation, int groups, + String data_layout, String kernel_layout, Optional out_layout, + DataType out_dtype) { + padding = GetCompletePadding2D(std::move(padding)); + if (output_padding.size() == 1) { + output_padding.push_back(output_padding[0]); + } + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is expected to be 4. " + "However, the given output_padding is " + << output_padding; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. 
However, the given dilation is " + << dilation; + + auto attrs = make_object(); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->output_padding = ConvertIntImmToInt64(output_padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->groups = groups; + attrs->data_layout = data_layout; + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout.value_or(data_layout)); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get("relax.nn.conv2d_transpose"); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv2d_transpose").set_body_typed(conv2d_transpose); + +StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2IOHW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"IOHW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCHW_shape[1]; + PrimExpr input_channel_kernel = weight_IOHW_shape[0]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Conv2dTranspose expects the channel size of the data should equal to the input channel " + "size of the weight. However, the data channel size is " + << input_channel_data << " while the weight input channel size is " + << input_channel_kernel); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv2dTranspose expects the number of input channels to be divisible by " + "the number of groups. 
However, the number of input channels is " + << input_channel_kernel << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel, attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value || + attrs->output_padding[1]->value >= attrs->strides[1]->value)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv2dTranspose expects the output padding less than the strides, but the " + "output padding is" + << attrs->output_padding << " while the strides are" << attrs->strides); + } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value && + attrs->output_padding[1]->value < attrs->strides[1]->value)) { + // Todo(relax-team): Trust the input padding at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = weight_IOHW_shape[2]; + PrimExpr kernel_w = weight_IOHW_shape[3]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + std::vector out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups; + + PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h + + attrs->dilation[0] * (kernel_h - 1) + attrs->output_padding[0] + 1; + PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w + + attrs->dilation[1] * (kernel_w - 1) + attrs->output_padding[1] + 1; + out_NCHW_shape[2] = analyzer->Simplify(out_h); + out_NCHW_shape[3] = analyzer->Simplify(out_w); + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose +// and unit test for mixed_precision +TVM_REGISTER_OP("relax.nn.conv2d_transpose") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h new file mode 100644 index 000000000000..833e730ee949 --- /dev/null +++ b/src/relax/op/nn/convolution.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file convolution.h + * \brief The functions to make Relax neural network convolution operator calls. 
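+ * + * For each spatial dimension, shape inference derives the output extent as + * out = floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride) + 1, + * mirroring InferStructInfoConv1d and InferStructInfoConv2d in convolution.cc.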
+ */ + +#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_ +#define TVM_RELAX_OP_NN_CONVOLUTION_H_ + +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +template +inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->groups = groups; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +/*! \brief 1D convolution */ +Expr conv1d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype); + +/*! \brief 2D convolution */ +Expr conv2d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype); + +/*! + * \brief Two dimensional transposed convolution operator. + * + * This operator is intended to be the backward operator of conv2d. It can be used to calculate the + * gradient of the result of conv2d w.r.t. the input of conv2d. + */ +Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array padding, + Array output_padding, Array dilation, int groups, + String data_layout, String kernel_layout, Optional out_layout, + DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_CONVOLUTION_H_ diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc new file mode 100644 index 000000000000..c3e18f8e3b19 --- /dev/null +++ b/src/relax/op/nn/nn.cc @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "nn.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.relu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/false); + +/* relax.nn.gelu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/true); + +/* relax.nn.silu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true); + +/* relax.nn.softmax */ +TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); + +Expr softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); + +StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (data_sinfo->IsUnknownNdim()) { + return data_sinfo; + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " + "dtype. However, the given input dtype is " + << data_sinfo->dtype); + } + const auto* attrs = call->attrs.as(); + NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + return data_sinfo; +} + +InferLayoutOutput InferLayoutSoftmax(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = FindAxis(layout->layout, attrs->axis); + return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax) + .set_attr("FRelaxInferLayout", InferLayoutSoftmax); + +/* relax.nn.log_softmax */ +Expr log_softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.log_softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); + +TVM_REGISTER_OP("relax.nn.log_softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax); + +bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, + const Array& input_sinfo, Array axes) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + + std::vector axes_non_neg; + if (!data_sinfo->IsUnknownNdim()) { + axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); + } + int n_axis = axes.size(); + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << op << " requires the input data to have float dtype. However, the given data dtype is " + << data_sinfo->dtype); + } + for (int i = 1; i < n_input; ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires all the input tensors to have the same dtype. 
However, the " + << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype + << " which is other than the input data's dtype " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != n_axis) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " requires the input " << op->arguments[i]->name + << " to have as many dimensions as the length of input axes. However, the " + "given one has ndim " + << input_sinfo[i]->ndim << ", which is other than the length of axes " + << n_axis); + } + } + + std::vector> axis_lengths; + axis_lengths.reserve(n_input); + if (const auto* data_shape = data_sinfo->shape.as()) { + std::vector lengths; + lengths.reserve(n_axis); + for (int d = 0; d < n_axis; ++d) { + lengths.push_back(data_shape->values[axes_non_neg[d]]); + } + axis_lengths.push_back(lengths); + } + for (int i = 1; i < n_input; ++i) { + if (const auto* shape = input_sinfo[i]->shape.as()) { + axis_lengths.push_back(shape->values); + } + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (int i = 1; i < static_cast(axis_lengths.size()); ++i) { + for (int d = 0; d < n_axis; ++d) { + if (analyzer->CanProve(axis_lengths[0][d] != axis_lengths[i][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires the input gamma, beta, etc., to have size same as the " + "lengths of the data on the given axes. However, there exists " + << axis_lengths[0] << " and " << axis_lengths[i] << " that are unequal."); + } else if (!analyzer->CanProveEqual(axis_lengths[0][d], axis_lengths[i][d])) { + return true; + } + } + } + return false; +} + +/* relax.nn.batch_norm */ +TVM_REGISTER_NODE_TYPE(BatchNormAttrs); + +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale) { + ObjectPtr attrs = make_object(); + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.batch_norm"); + return Call(op, + {std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean), + std::move(moving_var)}, + Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); + +StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); + + DataType dtype = input_sinfo[0]->dtype; + if (unknown_shape) { + return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim), + TensorStructInfo(dtype, /*ndim=*/1), + TensorStructInfo(dtype, /*ndim=*/1)}); + } else { + return TupleStructInfo({input_sinfo[0], input_sinfo[3], input_sinfo[4]}); + } +} + +InferLayoutOutput InferLayoutBatchNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 5; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = FindAxis(layout->layout, 
attrs->axis); + return InferLayoutOutput( + {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]}, + {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr("FInferStructInfo", InferStructInfoBatchNorm) + .set_attr("FRelaxInferLayout", InferLayoutBatchNorm); + +/* relax.nn.layer_norm */ +TVM_REGISTER_NODE_TYPE(LayerNormAttrs); + +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, + bool scale) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.layer_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); + +StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + + return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim) + : input_sinfo[0]; +} + +InferLayoutOutput InferLayoutLayerNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 3; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + std::vector new_axis; + for (const auto& axis : attrs->axes) { + new_axis.push_back(FindAxis(layout->layout, axis->value)); + } + new_attrs->axes = std::move(new_axis); + return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoLayerNorm) + .set_attr("FRelaxInferLayout", InferLayoutLayerNorm) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.nn.group_norm */ +TVM_REGISTER_NODE_TYPE(GroupNormAttrs); + +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array axes, double epsilon, bool center, bool scale) { + ObjectPtr attrs = make_object(); + attrs->num_groups = num_groups; + attrs->channel_axis = channel_axis; + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + 
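+ // center and scale control whether the beta offset and the gamma scaling are applied.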
attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.group_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); + +StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + int channel_axis = -1; + if (!data_sinfo->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + // channel_axis must be in axes. + if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " expects that channel_axis must not be in axes, but got channel_axis: " + << channel_axis << ", axes: " << attrs->axes); + } + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that data must be float, but got " << data_sinfo->dtype); + } + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape != nullptr && channel_axis != -1 && + analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of channel_axis must be divisible by " + << attrs->num_groups << ", but got " << data_shape->values[channel_axis]); + } + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } else if (channel_axis != -1) { + const auto* shape = input_sinfo[i]->shape.as(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); + } + } + } + } + return data_sinfo; +} + +InferLayoutOutput InferLayoutGroupNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 3; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + std::vector new_axes; + for (const auto& axis : attrs->axes) { + new_axes.push_back(FindAxis(layout->layout, axis->value)); + } + new_attrs->axes = std::move(new_axes); 
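+ // channel_axis must be remapped into the propagated layout as well, so the attrs stay consistent with the transposed data layout.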
+ new_attrs->channel_axis = FindAxis(layout->layout, attrs->channel_axis); + return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoGroupNorm) + .set_attr("FRelaxInferLayout", InferLayoutGroupNorm) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.nn.dropout */ +TVM_REGISTER_NODE_TYPE(DropoutAttrs); + +Expr dropout(Expr data, double rate) { + ObjectPtr attrs = make_object(); + attrs->rate = rate; + + static const Op& op = Op::Get("relax.nn.dropout"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); + +StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + return TupleStructInfo({data_sinfo, data_sinfo}); +} + +TVM_REGISTER_OP("relax.nn.dropout") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_attr("FInferStructInfo", InferStructInfoDropout) + .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.nn.cross_entropy_with_logits */ +StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo pred_sinfo = input_sinfo[0]; + TensorStructInfo label_sinfo = input_sinfo[1]; + + // infer dtype + DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); + + // infer ndim + if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && + pred_sinfo->ndim != label_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires predictions and labels to have the same ndim. " + "However, the ndim of predictions is " + << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); + } + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + + Optional> label_shape_value; + if (label_sinfo->shape.defined()) { + label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; + } + + if (pred_shape_value.defined() && label_shape_value.defined()) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (size_t i = 0; i < pred_shape_value.value().size(); ++i) { + if (analyzer->CanProve(pred_shape_value.value()[i] != label_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires the predictions and labels to have " + "the same shape. 
However, the shape of predictions at dim " + << i << " is" << pred_shape_value.value()[i] + << " while the shape of labels at this dim is " + << label_shape_value.value()[i]); + } + } + } + return TensorStructInfo(ShapeExpr(Array()), dtype); +} + +Expr cross_entropy_with_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_with_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") + .set_body_typed(cross_entropy_with_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h new file mode 100644 index 000000000000..f578f89346f7 --- /dev/null +++ b/src/relax/op/nn/nn.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nn.h + * \brief The functions to make Relax neural network operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_NN_H_ +#define TVM_RELAX_OP_NN_NN_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which construct the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param OpRegName The identifier of the operator in the registry. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + */ +#define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP(OpRegName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryArith); \ + RELAX_UNARY_OP_INTERFACE(OpName, OpRegName); + +/*! \brief Rectified linear unit. */ +Expr relu(Expr data); + +/*! \brief Gaussian Error Linear Units function. */ +Expr gelu(Expr data); + +/*! \brief Sigmoid Linear Unit function. */ +Expr silu(Expr data); + +/*! \brief Softmax function. */ +Expr softmax(Expr data, int axis); + +/*! \brief LogSoftmax function. */ +Expr log_softmax(Expr data, int axis); + +/*! \brief Compute batch normalization. */ +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale); + +/*! \brief Compute layer normalization. */ +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, + bool scale); + +/*! \brief Compute group normalization. 
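+ * + * The channels along channel_axis are divided into num_groups groups and normalized over the given axes; channel_axis must not itself appear in axes and its extent must be divisible by num_groups (see InferStructInfoGroupNorm in nn.cc).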
*/ +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array axes, double epsilon, bool center, bool scale); + +/*! + * \brief Applies the dropout operation to the input tensor. + * \param data The input data to the operator. + * \param rate The probability for an element to be reset to 0. + * \return A Tuple of two tensors. + * The first one is the original tensor and the second one is a + * mask tensor (1.0 where element not dropped, 0.0 where dropped) + */ +Expr dropout(Expr data, double rate); + +/*! \brief CrossEntropy with logits. */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_NN_H_ diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc new file mode 100644 index 000000000000..c31ce3dd0ba6 --- /dev/null +++ b/src/relax/op/nn/pooling.cc @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "pooling.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.max_pool2d and relax.nn.avg_pool2d */ +TVM_REGISTER_NODE_TYPE(Pool2DAttrs); + +Expr MakePool2d(String op_name, Expr data, Array pool_size, Array strides, + Array padding, Array dilation, bool ceil_mode, String layout, + Optional out_layout) { + padding = GetCompletePadding2D(std::move(padding)); + if (pool_size.size() == 1) { + pool_size.push_back(pool_size[0]); + } + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_EQ(pool_size.size(), 2) + << "The input pool_size length is expected to be 2. However, the given pool_size is " + << pool_size; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. 
However, the given dilation is " + << dilation; + + auto attrs = make_object(); + attrs->pool_size = ConvertIntImmToInt64(pool_size); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->ceil_mode = ceil_mode; + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + const Op& op = Op::Get(op_name); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, + Optional out_layout) { + return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, + layout, out_layout); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); + +StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = attrs->pool_size[0]; + PrimExpr kernel_w = attrs->pool_size[1]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + std::vector out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = data_NCHW_shape[1]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + if (attrs->ceil_mode) { + numerator_h += attrs->strides[0] - 1; + numerator_w += attrs->strides[1] - 1; + } + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutPool2d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + const auto* tensor_sinfo = GetStructInfoAs(call); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name(); + new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name(); + return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.max_pool2d") + .set_num_inputs(1) + 
.add_argument("data", "Tensor", "The input tensor") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoPool2D) + .set_attr("FRelaxInferLayout", InferLayoutPool2d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, + Optional out_layout) { + return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, + layout, out_layout); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); + +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoPool2D) + .set_attr("FRelaxInferLayout", InferLayoutPool2d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.nn.adaptive_avg_pool2d */ +TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); + +Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, + Optional out_layout) { + ObjectPtr attrs = make_object(); + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + if (output_size.defined()) { + Array _output_size = output_size.value(); + if (_output_size.size() == 1) { + _output_size.push_back(_output_size[0]); + } + CHECK_EQ(_output_size.size(), 2) + << "The output_size length is expected to be 2. However, the given output_size is " + << _output_size; + attrs->output_size = std::move(_output_size); + } + + static const Op& op = Op::Get("relax.nn.adaptive_avg_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); + +StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + !attrs->output_size.defined()) { + return data_sinfo; + } else { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array out_NCHW_shape(data_NCHW_shape); + if (attrs->output_size.defined()) { + out_NCHW_shape.Set(2, attrs->output_size.value()[0]); + out_NCHW_shape.Set(3, attrs->output_size.value()[1]); + } + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + const auto* tensor_sinfo = GetStructInfoAs(call); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr 
new_attrs = make_object(*attrs); + new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name(); + new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name(); + return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D) + .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h new file mode 100644 index 000000000000..63d2e76772e2 --- /dev/null +++ b/src/relax/op/nn/pooling.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pooling.h + * \brief The functions to make Relax neural network pooling operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_POOLING_H_ +#define TVM_RELAX_OP_NN_POOLING_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief 2D maximum pooling operator. */ +Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, Optional out_layout); + +/*! \brief 2D average pooling operator. */ +Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, Optional out_layout); + +/*! \brief 2D adaptive average pooling operator. */ +Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, + Optional out_layout); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_POOLING_H_ diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc new file mode 100644 index 000000000000..c641c45922d8 --- /dev/null +++ b/src/relax/op/op.cc @@ -0,0 +1,557 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include "op_common.h" + +namespace tvm { +namespace relax { + +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { + return TupleStructInfo(Array()); +} + +StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { + return ObjectStructInfo(); +} + +StructInfo InferStructInfoShapeOf(const Call& call, const BlockBuilder& ctx) { + // use the StructInfo of the argument + auto arg_sinfo = GetStructInfo(call->args[0]); + auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); + CHECK(tensor_sinfo) << "shape_of expects a tensor input, but received " << arg_sinfo + << "; use MatchCast if necessary"; + if (tensor_sinfo->ndim == kUnknownNDim) { + return ShapeStructInfo(kUnknownNDim); + } + // if the tensor shape is a Relax var or omitted, do not try to construct a shape expr from it + if (!tensor_sinfo->shape.defined() || tensor_sinfo->shape.as()) { + return ShapeStructInfo(tensor_sinfo->ndim); + } + // otherwise, copy over the values from the tensor shape + auto* tensor_shape = tensor_sinfo->shape.as(); + CHECK(tensor_shape); + return ShapeStructInfo(tensor_shape->values); +} + +// call_tir + +StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "sinfo_args should have exact 1 output struct info."); + } + CHECK(call->args[0]->IsInstance()) + << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " + << "However, gets " << call->args[0]; + return call->sinfo_args[0]; +} + +RELAY_REGISTER_OP("relax.call_tir") + .set_num_inputs(3) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") + .set_attr("FInferStructInfo", InferStructInfoCallTIR); + +Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, + Optional packed_ints) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. 
" + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_tir"); + Call call; + if (!packed_ints) { + // don't use additional optional argument + call = Call(op, {func, args}, {}, {out_sinfo}); + } else { + call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + } + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); + +// call_dps_packed + +StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "sinfo_args should have exact 1 output struct info."); + } + return call->sinfo_args[0]; +} + +RELAY_REGISTER_OP("relax.call_dps_packed") + .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked); + +Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_list) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) + << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_dps_packed"); + return Call(op, {func, args}, {}, {out_sinfo}); +} + +TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); + +// call builtin +StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() == 0) { + // by default return void. + return TupleStructInfo(Array()); + } else { + ICHECK_EQ(call->sinfo_args.size(), 1); + return call->sinfo_args[0]; + } +} + +TVM_REGISTER_OP("relax.call_builtin_with_ctx") + .set_num_inputs(4) + .add_argument("func", "Expr", "The builtin packed func.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx); + +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) { + static const Op& op = Op::Get("relax.call_builtin_with_ctx"); + return Call(op, {func, args}, Attrs(), sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); + +TVM_REGISTER_OP("relax.null_value") + .set_num_inputs(0) + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeCallNullValue() { + static const Op& op = Op::Get("relax.null_value"); + return Call(op, {}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); + +// print + +RELAY_REGISTER_OP("relax.print") + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is Python-style format string to use to print. 
The others " + "are values to print") + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FCallPacked", "relax.run.print"); + +Expr MakePrint(Array vals, StringImm format) { + Array params; + params.push_back(format); + for (const auto val : vals) { + params.push_back(val); + } + static const Op& op = Op::Get("relax.print"); + return Call(op, params); +} + +TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); + +// assert_op + +// can't actually name it assert or else Python will consider it a syntax error + +StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { + // Ensure that the condition argument is a boolean scalar. + // Also permitted is a tensor with unknown shape and unknown dtype + // (checked dynamically in that case). Returns void. + if (call->args.size() < 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Assert must have at least one argument (the condition)."); + } + StructInfo arg_struct_info = GetStructInfo(call->args[0]); + if (!IsBoolStructInfo(arg_struct_info)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The argument to assert must be a boolean scalar, but received " + << arg_struct_info); + } + return ReturnVoidStructInfo(call, ctx); +} + +RELAY_REGISTER_OP("relax.assert_op") + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is used as the assertion condition. The second value is " + "Python-style format string to use for displaying an error message, if the " + "assert fails. The others are used as format arguments if there is an error.") + .set_attr("FInferStructInfo", InferAssertStructInfo) + .set_attr("FCallPacked", "relax.run.assert_op"); + +Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { + static const Op& op = Op::Get("relax.assert_op"); + Array args = {condition}; + args.push_back(format); + for (auto val : vals) { + args.push_back(val); + } + return Call(op, args); +} + +TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); + +// make_closure + +RELAY_REGISTER_OP("relax.make_closure") + .set_num_inputs(2) + .add_argument("func", "Expr", "The closure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeClosure(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.make_closure"); + return Call(op, {func, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); + +// invoke_closure + +StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } +} + +RELAY_REGISTER_OP("relax.invoke_closure") + .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferStructInfo", InferStructInfoInvokeClosure); + +Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { + static const Op& op = Op::Get("relax.invoke_closure"); + return Call(op, {closure, args}, {}, sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); + +// shape_of + +RELAY_REGISTER_OP("relax.shape_of") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferStructInfo", InferStructInfoShapeOf); + +Expr MakeShapeOf(Expr expr) { + static const Op& op = 
Op::Get("relax.shape_of"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); + +// tensor_to_shape + +StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args.size() == 1); + ICHECK(call->args[0]->struct_info_.defined()); + const auto* tsinfo = GetStructInfoAs(call->args[0]); + ICHECK(tsinfo && tsinfo->shape.defined()); + ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); + ICHECK(shape_expr->values.size() == 1); + const IntImmNode* ndim = shape_expr->values[0].as(); + ICHECK(ndim); + return ShapeStructInfo(ndim->value); +} + +RELAY_REGISTER_OP("relax.tensor_to_shape") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo); + +Expr MakeTensorToShape(Expr expr) { + static const Op& op = Op::Get("relax.tensor_to_shape"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); + +// shape_to_tensor +StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args.size() == 1); + ICHECK(call->args[0]->struct_info_.defined()); + const auto* sinfo = GetStructInfoAs(call->args[0]); + ICHECK(sinfo); + int32_t ndim = sinfo->ndim; + return TensorStructInfo(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); +} + +RELAY_REGISTER_OP("relax.shape_to_tensor") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) + .set_attr("FCallPacked", "relax.run.shape_to_tensor"); + +Expr MakeShapeToTensor(Expr expr) { + static const Op& op = Op::Get("relax.shape_to_tensor"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); + +// alloc_tensor + +StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args[0].as()) + << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); + ICHECK(call->args[1].as()) + << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[1].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[0], out_dtype); +} + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") + .set_num_inputs(3) + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is to be " + "allocated at runtime. 
Index -1 is reserved for the host device.") + .set_attr("FInferStructInfo", InferStructInfoAllocateTensor); + +Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index) { + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + return Call(op, {shape, dtype, runtime_device_index}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); + +// memory planning alloc_storage + +RELAY_REGISTER_OP("relax.memory.alloc_storage") + .set_num_inputs(4) + .add_argument("total_space", "Expr", "The total space of the storage to allocate.") + .add_argument( + "virtual_device_index", "PrimValue", + "The virtual device index indicating on which device the storage is to be allocated, " + "Index -1 is reserved for the host device.") + .add_argument("storage_scope", "StringImm", + "The storage scope of the storage to allocate. Default is global.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm storage_scope, + DataTypeImm dtype) { + static const Op& op = Op::Get("relax.memory.alloc_storage"); + return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); + +// memory planning alloc_tensor + +StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(GetStructInfoAs(call->args[2])) + << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[2], out_dtype); +} + +RELAY_REGISTER_OP("relax.memory.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor); + +Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) { + static const Op& op = Op::Get("relax.memory.alloc_tensor"); + return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); + +// memory planning kill_storage + +RELAY_REGISTER_OP("relax.memory.kill_storage") + .set_num_inputs(1) + .add_argument("storage", "Expr", "The storage to be killed.") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillStorage(Expr storage) { + static const Op& op = Op::Get("relax.memory.kill_storage"); + return Call(op, {storage}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); + +// memory planning kill_tensor + +RELAY_REGISTER_OP("relax.memory.kill_tensor") + .set_num_inputs(1) + .add_argument("tensor", "Expr", "The tensor to be killed.") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillTensor(Expr tensor) { + static const Op& op = Op::Get("relax.memory.kill_tensor"); + return Call(op, {tensor}, {}, {}); +} + 
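// A minimal illustrative sketch (an assumption for exposition; this helper is not
// registered or exported anywhere): it shows how the memory-planning make-functions
// defined above compose into the lifecycle that lowering emits, i.e.
//   alloc_storage -> alloc_tensor -> (use) -> kill_tensor -> kill_storage.
// The byte size, shape, offset and dtype are placeholders, and in real lowered IR
// each call would be bound to a Var by the block builder rather than nested directly.
inline Array<Expr> ExampleMemoryPlanningSequence() {
  DataTypeImm f32(DataType::Float(32));
  Expr storage = MakeAllocStorage(/*size=*/ShapeExpr({IntImm(DataType::Int(64), 1024)}),
                                  /*virtual_device_index=*/PrimValue(IntImm(DataType::Int(64), 0)),
                                  /*storage_scope=*/StringImm("global"), /*dtype=*/f32);
  Expr tensor = MakeMemAllocTensor(storage, /*offset=*/PrimValue(IntImm(DataType::Int(64), 0)),
                                   /*shape=*/ShapeExpr({IntImm(DataType::Int(64), 16),
                                                        IntImm(DataType::Int(64), 16)}),
                                   /*dtype=*/f32);
  return {storage, tensor, MakeMemKillTensor(tensor), MakeMemKillStorage(storage)};
}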
+TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); + +// vm alloc_storage + +RELAY_REGISTER_OP("relax.vm.alloc_storage") + .set_num_inputs(3) + .add_argument("size", "Expr", "The size of the storage to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is " + "to be allocated at runtime.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype) { + static const Op& op = Op::Get("relax.vm.alloc_storage"); + return Call(op, {size, runtime_device_index, dtype}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); + +// vm alloc_tensor + +StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + if (const auto* output_shape = call->args[2].as()) { + return TensorStructInfo(GetRef(output_shape), out_dtype); + } + return TensorStructInfo(out_dtype, kUnknownNDim); +} + +RELAY_REGISTER_OP("relax.vm.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor); + +Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) { + static const Op& op = Op::Get("relax.vm.alloc_tensor"); + return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); + +// vm call_tir_dyn + +RELAY_REGISTER_OP("relax.vm.call_tir_dyn") + .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", + "The input arguments (list of tensors and last argument is ShapeExpr)") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeCallTIRDyn(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.vm.call_tir_dyn"); + return Call(op, {func, args}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); + +// builtin stop_lift_params +StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnaryArith(call, ctx); +} + +RELAY_REGISTER_OP("relax.builtin.stop_lift_params") + .set_num_inputs(1) + .add_argument("x", "Expr", "The input data") + .set_attr("FInferStructInfo", InferStructInfoStopLiftParams); + +Expr MakeStopLiftParams(Expr x) { + static const Op& op = Op::Get("relax.builtin.stop_lift_params"); + return Call(op, {x}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc new file mode 100644 index 000000000000..0997a3623e8a --- /dev/null +++ b/src/relax/op/op_common.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "op_common.h" + +#include + +namespace tvm { +namespace relax { + +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + if (static_cast(call->args.size()) != n_input) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " op should have " << n_input << " arguments"); + } + Array input_tensor_sinfo; + input_tensor_sinfo.reserve(n_input); + for (int i = 0; i < n_input; ++i) { + const auto* sinfo = GetStructInfoAs(call->args[i]); + if (sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " requires the input " << op->arguments[i]->name + << " to be Tensor. However, the given one has a " + << call->args[i]->struct_info_->GetTypeKey()); + } + input_tensor_sinfo.push_back(GetRef(sinfo)); + } + return input_tensor_sinfo; +} + +Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, + const Array& x1_shape, + const Array& x2_shape) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + int x1_ndim = x1_shape.size(); + int x2_ndim = x2_shape.size(); + int max_ndim = std::max(x1_ndim, x2_ndim); + + std::vector output_shape; + output_shape.reserve(max_ndim); + + int i = 1; + for (; i <= std::min(x1_ndim, x2_ndim); ++i) { + const PrimExpr& dim0 = x1_shape[x1_ndim - i]; + const PrimExpr& dim1 = x2_shape[x2_ndim - i]; + const auto* int_dim0 = dim0.as(); + const auto* int_dim1 = dim1.as(); + if (int_dim0 != nullptr && int_dim0->value == 1) { + output_shape.push_back(dim1); + } else if (int_dim1 != nullptr && int_dim1->value == 1) { + output_shape.push_back(dim0); + } else if (analyzer->CanProveEqual(dim0, dim1)) { + output_shape.push_back(dim0); + } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", the first input shape at dim " << x1_ndim - i + << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i + << " is " << dim1 << ", which are not broadcastable."); + } else { + // Use simple fallback when shape mismatch. + return NullOpt; + } + } + auto& longer_shape = (x1_ndim > x2_ndim) ? 
x1_shape : x2_shape; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape[max_ndim - i]); + } + return Array(output_shape.rbegin(), output_shape.rend()); +} + +std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, + const Array& axes) { + ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; + std::vector appeared_dims_set; + std::vector axes_non_neg; + appeared_dims_set.resize(ndim, /*value=*/false); + axes_non_neg.reserve(axes.size()); + for (const Integer& axis : axes) { + int _axis = axis->value; + if (_axis < -ndim || _axis >= ndim) { + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axis " << _axis + << " is out of range. The input tensor has " << ndim + << " dimensions, so axis should be in range [" + << -ndim << ", " << ndim << ")."); + } else if (_axis < 0) { + _axis = ndim + _axis; + } + + if (appeared_dims_set[_axis]) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op + << ", the input axes is required to be non-repetitive. However, there are " + "multiple given axes referring to axis " + << _axis); + } + appeared_dims_set[_axis] = true; + axes_non_neg.push_back(_axis); + } + return axes_non_neg; +} + +InferLayoutOutput InferLayoutUnaryEwise(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs)); +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h new file mode 100644 index 000000000000..616dded39e52 --- /dev/null +++ b/src/relax/op/op_common.h @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_common.h + * \brief A set of utilities and common functionality + * for Relax ops. + */ +#ifndef TVM_RELAX_OP_OP_COMMON_H_ +#define TVM_RELAX_OP_OP_COMMON_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include "../transform/infer_amp_utils.h" +#include "../transform/infer_layout_utils.h" + +namespace tvm { +namespace relax { + +/************ Op input struct info getter ************/ + +/*! + * \brief Get the tensor struct info of the operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of each input. + * \note This function require every input to be Tensor. The number of call arguments is required + * to match the number of inputs of the op being called. + */ +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); + +/*! 
+ * \brief Get the tensor struct info of the unary operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of the unary operator input. + * \throw Throw exception if the number of input is not one, or the struct info of the input is not + * a tensor struct info. + */ +inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + return GetInputTensorStructInfo(call, ctx)[0]; +} + +/************ Op registration macro ************/ + +/*! + * \brief Quick helper macro to register the operator to registry + * \param OpRegName The name of operator to register. The name passed in will + * be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_UNARY_OP(OpRegName) \ + TVM_REGISTER_OP("relax." OpRegName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input tensor.") \ + .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) \ + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + +/*! + * \brief Quick helper macro to expose a make-function to construct the operator. + * \param OpName The name of the operator as well as the make-function name, which will + * be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * \param OpRegName The identifier of the operator in the registry. + */ +#define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ + Expr OpName(Expr x) { \ + static const Op& op = Op::Get("relax." OpRegName); \ + return Call(op, {std::move(x)}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + +/************ Utilities ************/ + +/*! + * \brief Infer the struct info for unary elementwise ops. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param f_compute_out_dtype The function to compute the output dtype, with + * signature DataType f_compute_out_dtype(const TensorStructInfo& input_sinfo). + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \tparam Ftype the type of f_compute_out_dtype + * \return The inferred struct info. + */ +template +inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, + FType f_compute_out_dtype) { + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << call->op + << " requires the input tensor to have float dtype. However, the given input dtype is " + << input_sinfo->dtype); + } + auto output_sinfo = make_object(*input_sinfo.get()); + output_sinfo->dtype = f_compute_out_dtype(input_sinfo); + return TensorStructInfo(output_sinfo); +} + +/*! + * \brief Infer the struct info by returning the struct info of the input argument. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \tparam arg_index The index of the argument to infer the output dtype from. + * \return The inferred struct info. 
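+ * \note For example, instantiating the template below with arg_index = 0 makes the
+ *       op's output struct info simply mirror that of its first argument.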
+ */ +template +StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + if (static_cast(call->args.size()) != n_input) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " op should have " << n_input << " arguments"); + } + if (arg_index >= n_input) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " op has only " << n_input + << "arguments, but try to get the arg with index " << arg_index); + } + return GetStructInfo(call->args[arg_index]); +} + +/*! + * \brief Infer the struct info for unary arithmetic elementwise ops. It's also + * used in some NN operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \return The inferred struct info. + */ +template +StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); +} + +/*! + * \brief Layout infer util for unary elementwise ops. It will simply take the layout of the input. + * \param call The context Call to the operator. + * \param desired_layouts The desired layouts of certain ops. + * \param var_layout_map The layout of vars. + * \return The inferred layout result. + */ +InferLayoutOutput InferLayoutUnaryEwise(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map); + +/*! + * \brief Infer the output datatype for binary arithmetic operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_sinfo The struct info of the first operand + * \param x2_sinfo The struct info of the second operand + * \return The inferred output dtype. + * \throw Throw exception if the dtype of two input TensorStructInfo don’t match + */ +inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, + const TensorStructInfo& x1_sinfo, + const TensorStructInfo& x2_sinfo) { + if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { + return DataType::Void(); + } else if (x1_sinfo->dtype != x2_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype + << " must be equal for binary operators"); + } + return x1_sinfo->dtype; +} + +/*! + * \brief Infer the output shape for binary broadcast operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_shape The shape of the first operand. + * \param x2_shape The shape of the second operand. + * \return The inferred output shape after broadcasting. Or `NullOpt` if the output shape cannot be + * determined due to symbolic broadcast. + */ +Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, + const Array& x1_shape, + const Array& x2_shape); + +/*! + * \brief Convert all axes to non-negative indices, and meanwhile check if the given array of axes + * are all in range and non-repetitive with regards to the given ndim. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param ndim The ndim constraint, which is required to be known already. + * \param axes The axis indices to be checked + * \return The input axes in non-negative indexing. + * \throw Throw exception if there exists out-of-range axis index or repetitive indices. 
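+ * \note For example, with ndim = 4 the axes {-1, 1} normalize to {3, 1}, whereas
+ *       {1, -3} is rejected because both entries refer to axis 1.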
+ */ +std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, + const Array& axes); + +/*! + * \brief Convert the given axis to non-negative index. Meanwhile check if the axis is in range + * with regards to the given ndim. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param ndim The ndim constraint. + * \param axis The axis index to be checked + * \return The input axis in non-negative indexing. + * \throw Throw exception the given axis is out-of-range. + */ +inline int NormalizeAxis(const Call& call, const BlockBuilder& ctx, int ndim, int axis) { + return NormalizeAxes(call, ctx, ndim, {axis})[0]; +} + +/*! + * \brief Convert an array of integers to int64 dtype. + * \param int_imms The input IntImms to be converted. + * \return The conversion result, where every IntImm has dtype int64 + */ +inline Array ConvertIntImmToInt64(const Array& int_imms) { + return int_imms.Map([](const IntImm& i) { return Downcast(cast(DataType::Int(64), i)); }); +} + +/************ Utilities for NN operators ************/ + +/*! + * \brief Complete the padding to a 2-length array. + * - If the padding length is 1, the same padding is used on all left/right sides + * - If the padding length is 2, padding is in the order of (left, right) + * \param padding The given padding to be completed + * \return The completed padding. + * \throws Throws error if the input padding length is neither 1 or 2. + */ +inline Array GetCompletePadding1D(Array padding) { + if (padding.size() == 1) { + return {padding[0], padding[0]}; + } else if (padding.size() == 2) { + return padding; + } + LOG(FATAL) << "The input padding length is expected to be either 1 or 2. However, the given " + "padding is " + << padding; + throw; +} + +/*! + * \brief Complete the padding to a 4-length array. + * - If the padding length is 1, the same padding is used on all top/left/bottom/right sides + * - If the padding length is 2, top/bottom sides use padding[0] and left/right use padding[1] + * - If the padding length is 4, padding is in the order of (top, left, bottom, right) + * \param padding The given padding to be completed + * \return The completed padding. + * \throws Throws error if the input padding length is neither 1, 2 or 4. + */ +inline Array GetCompletePadding2D(Array padding) { + if (padding.size() == 1) { + return {padding[0], padding[0], padding[0], padding[0]}; + } else if (padding.size() == 2) { + return {padding[0], padding[1], padding[0], padding[1]}; + } else if (padding.size() == 4) { + return padding; + } + LOG(FATAL) << "The input padding length is expected to be either 1, 2 or 4. However, the given " + "padding is " + << padding; + throw; +} + +/*! + * \brief Check if the given tensor layout can be converted to the given target layout. + * If convertible, return the tensor layout and the bijective conversion in tir::Layout and + * tir::BijectiveLayout accordingly. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param tensor_layout The tensor layout to be checked + * \param tgt_layout The target layout to be matched + * \param tensor_name The name of the input tensor + * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout + * accordingly. 
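+ * \note For example, a tensor layout "NHWC" checked against the target layout "NCHW"
+ *       yields the bijective mapping used to permute shapes between the two orders.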
+ */ +inline std::pair CheckTensorLayout(const Call& call, + const BlockBuilder& ctx, + const String& tensor_layout, + const String& tgt_layout, + const String& tensor_name) { + tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); + tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); + if (!tensor2tgt.defined()) { + ctx->ReportFatal(Diagnostic::Error(call) << call->op << " requires the given " << tensor_name + << " layout to be convertible from " << tgt_layout + << " layout. However, the given layout " + << tensor_layout << " is not convertible."); + } + return {_tensor_layout, tensor2tgt}; +} + +/*! + * \brief Check if the given tensor struct info has expected ndim per the given layout (or the ndim + * is unknown), and try to cast the shape to ShapeExpr. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param sinfo The input tensor struct info to be checked. + * \param layout The layout that the given tensor is expected to have. + * \return The shape of the input tensor in ShapeExpr, or `NullOpt` if the shape is unknown. + */ +inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, + const TensorStructInfo& sinfo, + const tir::Layout& layout) { + if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", layout " << layout << " requires the input to be " + << layout.ndim() << "-dim tensor. However, the given input has ndim " + << sinfo->ndim); + } + if (const auto* shape_expr = sinfo->shape.as()) { + return GetRef(shape_expr); + } + return NullOpt; +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_OP_COMMON_H_ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc new file mode 100644 index 000000000000..96d1f01e8a6b --- /dev/null +++ b/src/relax/op/tensor/binary.cc @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.cc + * \brief binary broadcast operators. 
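+ * Broadcasting follows numpy semantics: shapes are aligned from the trailing
+ * dimension, and a dimension of size 1 is expanded to match the other operand.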
+ */ + +#include "binary.h" + +#include + +namespace tvm { +namespace relax { + +template +StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, + FType f_compute_out_dtype) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo x1_sinfo = input_sinfo[0]; + TensorStructInfo x2_sinfo = input_sinfo[1]; + + // DateType + DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo); + + // ndims + int output_ndim; + if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownNDim; + } else { + output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim); + } + + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + // Shapes and ndims + if (x1_shape && x2_shape) { + // If all inputs have shapes, directly infer shapes + Optional> output_shape = + InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); + if (!output_shape.defined()) { + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } else { + ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } + } else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorStructInfo(x1_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } +} + +StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoBroadcast(call, ctx, InferBinaryArithOpOutDtype); +} + +StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoBroadcast( + call, ctx, + [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& x1_sinfo, + const TensorStructInfo& x2_sinfo) { return DataType::Bool(); }); +} + +InferLayoutOutput InferLayoutBinaryEwise(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]); + LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]); + + auto* x1_sinfo = GetStructInfoAs(call->args[0]); + auto* x2_sinfo = GetStructInfoAs(call->args[1]); + + ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) + << "Unknown dim tensors should not be handled by this function"; + + if (x1_sinfo->ndim <= x2_sinfo->ndim) { + LayoutDecision out_layout = FollowDecision(layout1, x2_sinfo->ndim); + return InferLayoutOutput({layout1, out_layout}, {out_layout}, Attrs(call->attrs)); + } else { + LayoutDecision out_layout = FollowDecision(layout2, x1_sinfo->ndim); + return InferLayoutOutput({out_layout, layout2}, {out_layout}, Attrs(call->attrs)); + } +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); + +/***************** Comparison operators *****************/ + +RELAX_REGISTER_CMP_OP_AND_IMPL(equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater_equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(less); +RELAX_REGISTER_CMP_OP_AND_IMPL(less_equal); 
+RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal); + +/***************** Min/Max operators *****************/ + +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(minimum); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(maximum); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h new file mode 100644 index 000000000000..e386f9019fd4 --- /dev/null +++ b/src/relax/op/tensor/binary.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.h + * \brief The functions to make Relax binary arithmetic and comparison operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_BINARY_H_ +#define TVM_RELAX_OP_TENSOR_BINARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ + Expr OpName(Expr x1, Expr x2) { \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {x1, x2}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(2) \ + .add_argument("x1", "Tensor", "The first input tensor.") \ + .add_argument("x2", "Tensor", "The second input tensor.") \ + .set_attr("FRelaxInferLayout", InferLayoutBinaryEwise) \ + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + +#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastArith) + +#define RELAX_REGISTER_CMP_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastCMP) + +/***************** Arithmetic operators *****************/ + +/*! \brief Addition with numpy-style broadcasting. */ +Expr add(Expr x1, Expr x2); + +/*! \brief Division with numpy-style broadcasting. */ +Expr divide(Expr x1, Expr x2); + +/*! \brief Floor division with numpy-style broadcasting. */ +Expr floor_divide(Expr x1, Expr x2); + +/*! \brief Multiplication with numpy-style broadcasting. */ +Expr multiply(Expr x1, Expr x2); + +/*! \brief Power with numpy-style broadcasting. */ +Expr power(Expr x1, Expr x2); + +/*! \brief Subtraction with numpy-style broadcasting. */ +Expr subtract(Expr x1, Expr x2); + +/***************** Comparison operators *****************/ + +/*! 
\brief Broadcasted element-wise test for (lhs == rhs). */
+Expr equal(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise test for (lhs > rhs). */
+Expr greater(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise test for (lhs >= rhs). */
+Expr greater_equal(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise test for (lhs < rhs). */
+Expr less(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise test for (lhs <= rhs). */
+Expr less_equal(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise test for (lhs != rhs). */
+Expr not_equal(Expr x1, Expr x2);
+
+/***************** Min/Max *****************/
+
+/*! \brief Element-wise minimum */
+Expr minimum(Expr x1, Expr x2);
+
+/*! \brief Element-wise maximum */
+Expr maximum(Expr x1, Expr x2);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_TENSOR_BINARY_H_
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
new file mode 100644
index 000000000000..d4e5e166b72c
--- /dev/null
+++ b/src/relax/op/tensor/create.cc
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file create.cc
+ * \brief Creation operators.
+ */
+
+#include "create.h"
+
+#include
+
+namespace tvm {
+namespace relax {
+
+/* Initialization operators */
+TVM_REGISTER_NODE_TYPE(InitAttrs);
+
+/* relax.full */
+Expr full(ObjectRef shape, Expr fill_value, DataType dtype) {
+  Expr shape_in_expr{nullptr};
+  if (const auto* expr = shape.as<ExprNode>()) {
+    shape_in_expr = GetRef<Expr>(expr);
+  } else if (const auto* _array = shape.as<ArrayNode>()) {
+    shape_in_expr = ShapeExpr(GetRef<Array<PrimExpr>>(_array));
+  } else {
+    LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. "
+                  "However, the given one is "
+               << shape->GetTypeKey();
+  }
+
+  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+  attrs->dtype = dtype;
+
+  static const Op& op = Op::Get("relax.full");
+  return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(full);
+
+StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Full op should have 2 arguments");
+  }
+  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[0]);
+  const auto* fill_value_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  if (shape_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Full requires the input shape to be a Shape. However, the given one is "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (fill_value_sinfo == nullptr || fill_value_sinfo->ndim != 0) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "Full requires the input fill value to be zero rank Tensor.
However, the given one is " + << call->args[1]->struct_info_); + } + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype; + return TensorStructInfo(/*shape=*/call->args[0], out_dtype); +} + +TVM_REGISTER_OP("relax.full") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFull) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.full_like */ +Expr full_like(Expr x, Expr fill_value, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.full_like"); + return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); + +StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo fill_value_sinfo = input_sinfo[1]; + if (fill_value_sinfo->ndim != 0) { + ctx->ReportFatal(Diagnostic::Error(call) << "FullLike requires the input fill value to be zero " + "rank Tensor. However, the given one has ndim" + << fill_value_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.full_like") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("fill_value", "Tensor", "The scalar value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFullLike) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +// Structure info inference for ones and zeros +StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Ones/Zeros should have 1 argument"); + } + + const auto* shape_sinfo = GetStructInfoAs(call->args[0]); + if (shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Ones/Zeros requires the input shape to be a Shape. 
However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + const auto* attrs = call->attrs.as(); + return TensorStructInfo(/*shape=*/call->args[0], attrs->dtype); +} + +// Structure info inference for ones_like and zeros_like +StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +/* relax.ones & relax.ones_like */ +Expr ones(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.ones"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr ones_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.ones_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); +TVM_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); + +TVM_REGISTER_OP("relax.ones") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +TVM_REGISTER_OP("relax.ones_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.zeros & relax.zeros_like */ +Expr zeros(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.zeros"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr zeros_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.zeros_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); +TVM_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); + +TVM_REGISTER_OP("relax.zeros") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +TVM_REGISTER_OP("relax.zeros_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.tril & relax.triu */ +TVM_REGISTER_NODE_TYPE(TriluAttrs); + +Expr tril(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.tril"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +Expr triu(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.triu"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(tril); +TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(triu); + +StructInfo 
InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { + ctx->ReportFatal(Diagnostic::Error(call) << call->op + << " requires the input tensor to have at least two " + "dimensions. However, the given input has " + << data_sinfo->ndim << " dimension(s)."); + } + return data_sinfo; +} + +TVM_REGISTER_OP("relax.tril") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +TVM_REGISTER_OP("relax.triu") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h new file mode 100644 index 000000000000..c1ade470b4e8 --- /dev/null +++ b/src/relax/op/tensor/create.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file create.h + * \brief The functions to make Relax tensor-creation operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_CREATE_H_ +#define TVM_RELAX_OP_TENSOR_CREATE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Fill array with scalar value. + * \param shape The shape of the created tensor. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. + * If dtype is not given, it will by default use the dtype of fill_value. + * \return The result tensor. + */ +Expr full(ObjectRef shape, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor such that + * - its shape is the same as the input data tensor's shape, + * - its value is filled with the input scalar fill value. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr full_like(Expr x, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor of all ones, with the input shape and dtype. + * \param shape The shape of the created tensor. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr ones(Expr shape, DataType dtype); + +/*! + * \brief Construct a tensor with all ones, with shape of the input tensor shape. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. 
+ * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr ones_like(Expr x, DataType dtype); + +/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */ +Expr zeros(Expr shape, DataType dtype); + +/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */ +Expr zeros_like(Expr x, DataType dtype); + +/*! \brief Return the lower triangular part of a matrix or a batch of matrices. */ +Expr tril(Expr x, int k); + +/*! \brief Return the upper triangular part of a matrix or a batch of matrices. */ +Expr triu(Expr x, int k); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_CREATE_H_ diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc new file mode 100644 index 000000000000..18747fedcda0 --- /dev/null +++ b/src/relax/op/tensor/datatype.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.cc + * \brief Datatype operators. 
+ */ + +#include "datatype.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.astype */ +TVM_REGISTER_NODE_TYPE(AstypeAttrs); + +Expr astype(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.astype"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); + +StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ObjectPtr new_sinfo = make_object(*sinfo.get()); + new_sinfo->dtype = attrs->dtype; + return TensorStructInfo(new_sinfo); +} + +TVM_REGISTER_OP("relax.astype") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoAstype) + .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.wrap_param */ +TVM_REGISTER_NODE_TYPE(WrapParamAttrs); + +Expr MakeWrapParam(Expr data, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.wrap_param"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); + +StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ObjectPtr new_sinfo = make_object(*sinfo.get()); + new_sinfo->dtype = attrs->dtype; + return TensorStructInfo(new_sinfo); +} + +TVM_REGISTER_OP("relax.wrap_param") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoWrapParam); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h new file mode 100644 index 000000000000..b612c45fc941 --- /dev/null +++ b/src/relax/op/tensor/datatype.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.h + * \brief The functions to make Relax datatype operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_DATATYPE_H_ +#define TVM_RELAX_OP_TENSOR_DATATYPE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Cast input tensor to the given data type. + * \param x The input data to the operator. + * \param dtype The target data type + * \return The casted result. + */ +Expr astype(Expr x, DataType dtype); + +/*! + * \brief A wrapper to wrap the input const tensor to the given data type. 
+ * \param x The input const tensor to the operator. + * \param dtype The target data type + * \return The wrapped result. + */ +Expr wrap_param(Expr x, DataType dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_DATATYPE_H_ diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc new file mode 100644 index 000000000000..ac3ce084c435 --- /dev/null +++ b/src/relax/op/tensor/index.cc @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.cc + * \brief indexing operators. + */ + +#include "index.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.take */ +TVM_REGISTER_NODE_TYPE(TakeAttrs); + +Expr take(Expr x, Expr indices, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.take"); + return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); + +StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo indices_sinfo = input_sinfo[1]; + if (indices_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to be 1-dimensional tensor. However, " + "the given indices ndim is " + << indices_sinfo->ndim); + } else if (!indices_sinfo->IsUnknownDtype() && + !(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* attrs = call->attrs.as(); + if (!attrs->axis.defined() && data_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op expects the input data to be 1-dimensional tensor when the axis " + "is not specified. However, the given data tensor has ndim " + << data_sinfo->ndim); + } + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int axis = attrs->axis.defined() + ? 
NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value) + : 0; + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + if (data_shape == nullptr || indices_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + Array output_shape = data_shape->values; + output_shape.Set(axis, indices_shape->values[0]); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.take") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The source tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") + .set_attr("FInferStructInfo", InferStructInfoTake); + +/* relax.strided_slice */ +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); + +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides) { + int n_axis = axes.size(); + CHECK_EQ(static_cast(begin.size()), n_axis) + << "StridedSlice requires the number of begin indices to equal the number of axes."; + CHECK_EQ(static_cast(end.size()), n_axis) + << "StridedSlice requires the number of end indices to equal the number of axes."; + if (strides.defined()) { + CHECK_EQ(static_cast(strides.value().size()), n_axis) + << "StridedSlice requires the number of strides to equal the number of axes."; + } + + // Todo(relax-team): We are going to support dynamic strided slice, where + // begin/end/stride can be not static at compile time. Therefore, begin/end/stride + // should not be part of StridedSliceAttrs, as we only allow static values to + // reside in attributes. However, using ShapeExpr to represent these + // arrays is not conceptually right, because they are not describing a + // concrete shape. The proper way to support dynamic strided slice is to use + // Tuple of PrimValue to represent begin/end/stride. Since at this moment + // we have no support for PrimValue, we store begin/end/stride as attribute + // fields as a workaround. + // Will switch to Tuple of PrimValue after introducing PrimValue. + auto f_convert_to_int64 = [](const PrimExpr& value) { + if (value->IsInstance()) { + return cast(DataType::Int(64), value); + } + CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the input begin/end/stride " + "values to be all int64. However, the given " + << value << " has dtype " << value->dtype; + return value; + }; + + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->begin = begin.Map(f_convert_to_int64); + attrs->end = end.Map(f_convert_to_int64); + attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) : strides; + + static const Op& op = Op::Get("relax.strided_slice"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); + +inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t stride) { + // Same as topi strided slice CanonicalizeIndex function in + // include/tvm/topi/detail/strided_slice.h + PrimExpr begin_range = stride < 0 ? -1 : 0; + PrimExpr end_range = stride < 0 ? 
extent - 1 : extent; + index = if_then_else(index < 0, index + extent, index); + return min(max(index, begin_range), end_range); // NOLINT +} + +PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const PrimExpr& length) { + begin = CanonicalizeIndex(begin, length, stride); + end = CanonicalizeIndex(end, length, stride); + + if (stride < 0) { + return ceildiv(begin - end, IntImm(DataType::Int(64), -stride)); + } else { + return ceildiv(end - begin, IntImm(DataType::Int(64), stride)); + } +} + +StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axes.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + int n_axis = axes.size(); + Array strides = attrs->strides.defined() + ? attrs->strides.value() + : Array(n_axis, IntImm(DataType::Int(64), 1)); + std::vector int_strides; + int_strides.reserve(n_axis); + // Only do output shape inference when all the begin/end/stride values are integers. + for (int i = 0; i < n_axis; ++i) { + const auto* int_begin = attrs->begin[i].as(); + const auto* int_end = attrs->end[i].as(); + const auto* int_stride = strides[i].as(); + if (!int_begin || !int_end || !int_stride) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + int_strides.push_back(int_stride->value); + } + + Array output_shape = data_shape->values; + for (int i = 0; i < n_axis; ++i) { + ICHECK_NE(int_strides[i], 0) + << "Strided slice requires stride to be non-zero but got 0 for axis " << axes[i] << "."; + output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], int_strides[i], + data_shape->values[axes[i]])); + } + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutStridedSlice(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + std::vector new_axes; + for (const auto& axis : attrs->axes) { + new_axes.push_back(FindAxis(existing_layout->layout, axis->value)); + } + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axes = std::move(new_axes); + return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.strided_slice") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The source tensor to be sliced.") + .set_attr("FInferStructInfo", InferStructInfoStridedSlice) + .set_attr("FRelaxInferLayout", InferLayoutStridedSlice) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h new file mode 100644 index 000000000000..6944493a0fd6 --- /dev/null +++ b/src/relax/op/tensor/index.h @@ -0,0 
+1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.h + * \brief The functions to make Relax tensor indexing operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_INDEX_H_ +#define TVM_RELAX_OP_TENSOR_INDEX_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Take elements from a tensor along an axis. + * \param x The source tensor. + * \param indices The indices of the values to extract. + * It is required to be a one-dimensional tensor which has integer dtype. + * \param axis The axis over which to select values. + * If it is `NullOpt`, the input tensor is required to be one-dimensional. + * \return The taken result. + */ +Expr take(Expr x, Expr indices, Optional axis); + +/*! + * \brief Strided slice of a tensor. + * \param x The source tensor to be sliced. + * \param axes Axes along which slicing is applied. + * \param begin The indices to begin with in the slicing, inclusive. + * \param end The indices indicating end of the slice, exclusive. + * \param strides Specifies the stride values, it can be negative in that case, + * the input tensor will be reversed in that particular axis. + * If it is `NullOpt`, it by default is an list of ones of the same length as `axes`. + * \return The sliced result + */ +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_INDEX_H_ diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc new file mode 100644 index 000000000000..afcc7fefe70e --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.cc + * \brief Linear algebra operators. 
+ */ + +#include "linear_algebra.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.matmul */ +TVM_REGISTER_NODE_TYPE(MatmulAttrs); + +Expr matmul(Expr x1, Expr x2, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.matmul"); + return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); + +StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo x1_sinfo = input_sinfo[0]; + TensorStructInfo x2_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo) + : attrs->out_dtype; + + if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + return TensorStructInfo(out_dtype, kUnknownNDim); + } + int x1_ndim = x1_sinfo->ndim; + int x2_ndim = x2_sinfo->ndim; + if (x1_ndim == 0 || x2_ndim == 0) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires both inputs to have at least 1 dimension. However, " + << (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor."); + } + + int x1_prepended = 0; + int x2_appended = 0; + if (x1_ndim == 1) { + x1_ndim = 2; + x1_prepended = 1; + } + if (x2_ndim == 1) { + x2_ndim = 2; + x2_appended = 1; + } + int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended; + + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + if (x1_shape == nullptr || x2_shape == nullptr) { + return TensorStructInfo(out_dtype, output_ndim); + } + + Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + Optional> output_shape_prefix = + InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); + if (!output_shape_prefix.defined()) { + return TensorStructInfo(out_dtype, output_ndim); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; + PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; + if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires the reduction length of x1 and x2 to be equal. 
However, " + "the reduction lengths of x1 and x2 are " + << x1_reduction_length << " and " << x2_reduction_length << " respectively."); + } + + Array output_shape = output_shape_prefix.value(); + if (!x1_prepended) { + output_shape.push_back(x1_shape->values[x1_ndim - 2]); + } + if (!x2_appended) { + output_shape.push_back(x2_shape->values[x2_ndim - 1]); + } + ICHECK_EQ(static_cast(output_shape.size()), output_ndim); + return TensorStructInfo(ShapeExpr(output_shape), out_dtype); +} + +Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { + return Downcast(matmul(call->args[0], call->args[1], out_dtype)); +} + +TVM_REGISTER_OP("relax.matmul") + .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoMatmul) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) + .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h new file mode 100644 index 000000000000..af614c1f30d5 --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.h + * \brief The functions to make Relax linear algebra operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief General matrix multiplication of two tensors. + * The semantics and output shape deduction rule is specified as + * https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html. + * \param x1 The first input tensor. + * \param x2 The second input tensor. + * \param out_dtype The data type of the matmul result. + * When it is not specified, the output dtype will be the the same as input dtype. + * \return The computed result. + */ +Expr matmul(Expr x1, Expr x2, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc new file mode 100644 index 000000000000..faa5ee3bc099 --- /dev/null +++ b/src/relax/op/tensor/manipulate.cc @@ -0,0 +1,1363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manipulate.cc + * \brief Manipulation operators. + */ + +#include "manipulate.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/* relax.broadcast_to */ +Expr broadcast_to(Expr x, Expr shape) { + static const Op& op = Op::Get("relax.broadcast_to"); + return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); + +StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "broadcast_to should take 2 arguments."); + } + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* tgt_shape_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to requires the input data to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (tgt_shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to requires the input new shape to be Shape. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + if (!data_sinfo->IsUnknownNdim() && !tgt_shape_sinfo->IsUnknownNdim() && + tgt_shape_sinfo->ndim < data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "broadcast_to expects the input shape to have the number of ndim at least " + "as the input tensor's. However, the given tensor has ndim " + << data_sinfo->ndim << " while the target shape has ndim " + << tgt_shape_sinfo->ndim); + } + + // Trust the input target shape when there is no possibility to do any compile-time check. + if (!data_sinfo->shape.defined()) { + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); + } + ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); + if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) { + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + Array old_shape_value = shape_sinfo->values.value(); + Array tgt_shape_value = tgt_shape_sinfo->values.value(); + int old_ndim = old_shape_value.size(); + int tgt_ndim = tgt_shape_value.size(); + for (int i = 0; i < old_ndim; ++i) { + PrimExpr old_len = old_shape_value[old_ndim - i - 1]; + PrimExpr tgt_len = tgt_shape_value[tgt_ndim - i - 1]; + const auto* old_len_int = old_len.as(); + if (old_len_int != nullptr && old_len_int->value == 1) { + continue; + } else if (analyzer->CanProve(old_len != tgt_len)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to expects the input tensor shape is broadcastable to the target shape. 
" + "The target shape at dim " + << tgt_ndim - i - 1 << " is " << tgt_len << " while the input tensor shape at dim " + << old_ndim - i - 1 << " is " << old_len << ", which are not equal."); + } + // Todo(relax-team): revisit here for better check on if the tensor length + // is consistent with the length in the given shape. + } + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.broadcast_to") + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The target shape.") + .set_attr("FInferStructInfo", InferStructInfoBroadcastTo) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.concat */ +TVM_REGISTER_NODE_TYPE(ConcatAttrs); + +Expr concat(Expr tensors, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.concat"); + return Call(op, {std::move(tensors)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); + +Array GetTensorSInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& expr) { + const auto* tuple_sinfo = GetStructInfoAs(expr); + if (tuple_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << call->op + << " expects the input to be a Tuple of Tensors. However, the given input is " + << expr->struct_info_->GetTypeKey()); + } + + Array tensor_sinfo; + tensor_sinfo.reserve(tuple_sinfo->fields.size()); + for (StructInfo field_sinfo : tuple_sinfo->fields) { + const auto* field_tensor_sinfo = field_sinfo.as(); + if (field_tensor_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " + << expr->struct_info_); + } + tensor_sinfo.push_back(GetRef(field_tensor_sinfo)); + } + return tensor_sinfo; +} + +Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, + int axis) { + bool shape_unknown = false; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr concat_sum = IntImm(DataType::Int(64), 0); + for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { + // For the specified axis, we compute the sum of shape value over each tensor. + if (d == axis) { + for (Array shape_value : shape_values) { + concat_sum += shape_value[d]; + } + continue; + } + + // For other axes, we check the equality of all tensors' shape values, to ensure safety. + for (int i = 1; i < static_cast(shape_values.size()); ++i) { + if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects the input tensors to have the same shape on every " + "dimension except the one indicated by the input axis. 
However, the " + "input contains tensors whose shapes on dimension " + << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); + } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { + shape_unknown = true; + } + } + } + + if (shape_unknown) { + return NullOpt; + } + Array output_shape = shape_values[0]; + output_shape.Set(axis, concat_sum); + return output_shape; +} + +StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); + } + Array tensor_sinfo = GetTensorSInfoFromTuple(call, ctx, call->args[0]); + if (tensor_sinfo.empty()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat op expects at least one tensor in the input Tuple. However, the " + "given input Tuple is empty."); + } + + const auto* attrs = call->attrs.as(); + int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; + DataType output_dtype = DataType::Void(); + bool shape_unknown = false; + bool is_void_dtype = false; + std::vector> shape_values; + shape_values.reserve(tensor_sinfo.size()); + + for (TensorStructInfo sinfo : tensor_sinfo) { + // Update the output dtype. + if (sinfo->dtype.is_void()) { + is_void_dtype = true; + } else if (output_dtype.is_void()) { + output_dtype = sinfo->dtype; + } else if (sinfo->dtype != output_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have the same dtype. However, the " + "input contains tensors with dtype " + << output_dtype << " and " << sinfo->dtype); + } + + // Update the output ndim. + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim 1 when the input axis is undefined. + if (output_ndim == kUnknownNDim) { + output_ndim = sinfo->ndim; + } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have same ndim. However, the " + "input contains tensors with ndim " + << output_ndim << " and " << sinfo->ndim); + } + + // Update the shape values for best effort check. + const auto* shape_expr = sinfo->shape.as(); + if (shape_expr != nullptr) { + shape_values.push_back(shape_expr->values); + continue; + } + shape_unknown = true; + + if (!sinfo->shape.defined()) { + continue; + } + // Keep the shape value for equality check. + ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); + if (shape_sinfo->values.defined()) { + shape_values.push_back(shape_sinfo->values.value()); + } + } + + if (is_void_dtype) { + output_dtype = DataType::Void(); + } + if (output_ndim == kUnknownNDim) { + return tensor_sinfo.size() == 1 ? tensor_sinfo[0] : TensorStructInfo(output_dtype, output_ndim); + } + + int axis = + attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; + // If there is only one input tensor, no action is needed. + if (tensor_sinfo.size() == 1) { + return tensor_sinfo[0]; + } + if (shape_values.empty()) { + return TensorStructInfo(output_dtype, output_ndim); + } + + // As long as the there is known shape value, we will do the best effort check to ensure safety. 
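+ // For example, concatenating tensors of shapes (2, 3) and (2, 5) along axis 1 checks that
+ // every non-concat dimension matches (2 == 2) and sums the concat dimension, yielding
+ // (2, 8). If some non-concat dimension cannot be proven equal (e.g. symbolic lengths), the
+ // result below falls back to an unknown shape instead of raising an error.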
+ Optional> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis); + + if (shape_unknown || !output_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } else { + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } +} + +InferLayoutOutput InferLayoutConcat(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); + ICHECK(nlayout.IsNested()); + ICHECK(nlayout.NestedArray()[0].IsLeaf()); + + int n_tensor = nlayout.NestedArray().size(); + LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); + Array input_layouts, output_layouts; + for (int i = 0; i < n_tensor; ++i) { + input_layouts.push_back(layout); + } + output_layouts.push_back(layout); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = Integer(FindAxis(layout->layout, attrs->axis.value_or(0)->value)); + return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.concat") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") + .set_attr("FInferStructInfo", InferStructInfoConcat) + .set_attr("FRelaxInferLayout", InferLayoutConcat) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.expand_dims */ +TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); + +Expr expand_dims(Expr x, Array axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.expand_dims"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); + +StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axis.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int n_new_dim = attrs->axis.size(); + int output_ndim = data_sinfo->ndim + n_new_dim; + std::vector axes = NormalizeAxes(call, ctx, output_ndim, attrs->axis); + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, output_ndim); + } + + std::vector output_shape; + output_shape.resize(output_ndim, PrimExpr()); + for (int i = 0; i < n_new_dim; ++i) { + output_shape[axes[i]] = IntImm(DataType::Int(64), 1); + } + + int i_data_shape = 0; + for (int i = 0; i < output_ndim; ++i) { + if (output_shape[i].defined()) { + continue; + } + ICHECK_LT(i_data_shape, data_sinfo->ndim); + output_shape[i] = data_shape->values[i_data_shape]; + ++i_data_shape; + } + ICHECK_EQ(i_data_shape, data_sinfo->ndim); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutExpandDims(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + + 
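+ // The block below keeps the input in its existing layout and derives the output layout by
+ // slotting the newly inserted axes into their requested positions. As a rough illustration
+ // (assuming InitialLayout(3) is "ABC" and the input stays in that layout), expanding a 3-d
+ // tensor with axis = [1] yields the output layout "ABCD", where 'B' marks the new unit
+ // dimension.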
LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + int ndim = tensor_sinfo->ndim; + int n_new_dim = attrs->axis.size(); + int output_ndim = ndim + n_new_dim; + std::vector is_new_dim(output_ndim, false); + for (const auto& axis : attrs->axis) { + is_new_dim[(axis->value + output_ndim) % output_ndim] = true; + } + std::string new_layout; + for (int i = 0; i < output_ndim; ++i) { + if (!is_new_dim[i]) { + new_layout.push_back('A' + i); + } + } + new_layout = TransposeStrLike(new_layout, InitialLayout(ndim), existing_layout->layout); + std::string output_layout; + for (int i = 0, j = 0; i < output_ndim; ++i) { + if (is_new_dim[i]) { + output_layout.push_back('A' + i); + } else { + output_layout.push_back(new_layout.at(j++)); + } + } + return InferLayoutOutput({existing_layout}, {LayoutDecision(Layout(output_layout))}, + Attrs(call->attrs)); +} + +TVM_REGISTER_OP("relax.expand_dims") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoExpandDims) + .set_attr("FRelaxInferLayout", InferLayoutExpandDims) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +// Helper function for flatten and reshape. +PrimExpr ComputeShapeProduct(const Array& shape_values) { + PrimExpr shape_prod = IntImm(DataType::Int(64), 1); + for (PrimExpr value : shape_values) { + shape_prod *= value; + } + return shape_prod; +} + +/* relax.flatten */ +Expr flatten(Expr x) { + static const Op& op = Op::Get("relax.flatten"); + return Call(op, {std::move(x)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); + +StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1); + } else if (data_sinfo->ndim == 0) { + return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype); + } else if (data_sinfo->ndim == 1) { + return data_sinfo; + } + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1); + } + PrimExpr shape_prod = ComputeShapeProduct(data_shape->values); + return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.flatten") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoFlatten) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.layout_transform */ +TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); + +Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value) { + ObjectPtr attrs = make_object(); + attrs->index_map = std::move(index_map); + attrs->pad_value = std::move(pad_value); + + static const Op& op = Op::Get("relax.layout_transform"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); + +StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + tir::IndexMap index_map = attrs->index_map; + Optional optional_pad_value = attrs->pad_value; + + // Check pad_value has same dtype as input. 
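+ // The checks below require the pad value (if given) to match the input dtype and the index
+ // map's source rank to match the input rank; the output shape is then the input shape
+ // mapped through the index map. For instance, an index map
+ // (n, c, h, w) -> (n, c // 4, h, w, c % 4) turns shape (1, 16, 224, 224) into
+ // (1, 4, 224, 224, 4).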
+ if (optional_pad_value.defined()) { + PrimExpr padded_value = optional_pad_value.value()->value; + if (padded_value->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "layout_transform pad_value dtype (" << padded_value->dtype + << ") and input dtype (" << data_sinfo->dtype << ") must be the same"); + } + } + + if (data_sinfo->IsUnknownNdim()) { + // Todo(relax-team): revisit here for better check on if the input tensor has desired ndim. + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + // If rank is known, check that it is compatible with the index_map, i.e., #dims match. + if (index_map->initial_indices.size() != static_cast(data_sinfo->ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "number of dimensions in input must match the number of source dimensions " + "in index map, but got " + << data_sinfo->ndim << " != " << index_map->initial_indices.size()); + } + + if (!data_sinfo->shape.defined()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); + if (!shape_sinfo->values.defined()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + Array output_shape = index_map->MapShape(shape_sinfo->values.value()); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.layout_transform") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoLayoutTransform) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.permute_dims */ +TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs); + +Expr permute_dims(Expr x, Optional> axes) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + + static const Op& op = Op::Get("relax.permute_dims"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); + +bool IsIdentityPermutation(const std::vector& permutation) { + for (int i = 0; i < static_cast(permutation.size()); ++i) { + if (permutation[i] != i) { + return false; + } + } + return true; +} + +StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim same as the number of input axes. + if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + if (attrs->axes.defined()) { + int n_axis = attrs->axes.value().size(); + if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "PermuteDims expects the number of input axes to equal the ndim of the " + "input tensor. 
However, the tensor ndim is " + << data_sinfo->ndim << " while the given number of axes is " << n_axis); + } + } + + std::vector axes; + if (attrs->axes.defined()) { + axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes.value()); + } else { + // Construct the reverse permutation via std::iota + axes.resize(data_sinfo->ndim); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if (IsIdentityPermutation(axes)) { + return data_sinfo; + } + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + std::vector new_shape; + new_shape.reserve(data_sinfo->ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + new_shape.push_back(data_shape->values[axes[i]]); + } + return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutPermuteDims(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_sinfo->ndim; + + LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + Array order; + if (attrs->axes.defined()) { + order = attrs->axes.value(); + } else { + order.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + order.push_back(Integer(ndim - i - 1)); + } + } + std::string order_str; + for (const auto& axis : order) { + order_str.push_back(axis->value + 'A'); + } + String new_axes = + TransposeStrLike(InitialLayout(ndim).name(), existing_layout->layout, order_str); + Array new_order; + for (size_t i = 0; i < new_axes.size(); ++i) { + new_order.push_back(Integer(new_axes.at(i) - 'A')); + } + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axes = new_order; + return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(ndim)}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.permute_dims") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoPermuteDims) + .set_attr("FRelaxInferLayout", InferLayoutPermuteDims) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.reshape */ +Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { + if (const auto* e = shape.as()) { + return GetRef(e); + } + + const auto* array = shape.as(); + CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + int dim_to_infer = -1; + // Keep track of which dimensions should be copied from input. + std::vector zero_dims; + for (int i = 0; i < static_cast(array->size()); ++i) { + const auto* _len = array->at(i).as(); + CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + PrimExpr len = GetRef(_len); + CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " + "integers. However, the give new shape is " + << shape; + const auto* int_len = len.as(); + if (int_len != nullptr && int_len->value == 0) { + // Note that this dimension should be copied from the original shape. 
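+ // For example, reshaping a (2, 3, 4) tensor with new shape (0, -1) copies dimension 0 from
+ // the input (giving 2) and infers the remaining dimension as 24 / 2 = 12, producing (2, 12).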
+ zero_dims.push_back(i); + } else if (int_len != nullptr && int_len->value == -1) { + CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, " + "there are multiple \"-1\" in the given new shape " + << shape; + dim_to_infer = i; + } else { + CHECK(int_len == nullptr || int_len->value > 0) + << "Reshape requires all values in the new shape to be positive except a single \"-1\". " + "However, the given new shape is " + << shape; + } + } + + Array array_ref = GetRef>(array); + // When there is no dimension to infer, just return the input array as ShapeExpr. + if (dim_to_infer == -1 && zero_dims.empty()) { + return ShapeExpr(array_ref); + } + + // Otherwise, we require the input tensor to have known shape value for inference. + const auto* data_sinfo = GetStructInfoAs(data); + CHECK(data_sinfo != nullptr) + << "Reshape expects the input data to be a Tensor. However, the given input is " + << data->struct_info_->GetTypeKey(); + CHECK(data_sinfo->shape.defined()) + << "Reshape expects the input tensor to have known shape when there is some dimension length " + "to infer. However, the given input has no shape."; + const auto* shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); + CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) + << "Reshape expects the input tensor to have known shape when there is some dimension length " + "to infer. However, the given input shape is " + << data_sinfo->shape << " whose shape value is unknown."; + + // Set any 0 valued dimensions to match the corresponding input shape. + if (!zero_dims.empty()) { + for (int i : zero_dims) { + array_ref.Set(i, shape_sinfo->values.value()[i]); + } + } + + // Set any -1 dimensions to complete the number of appropriate elements. + // Start by computing the shape product of all positive indices. + PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + for (int i = 0; i < static_cast(array_ref.size()); ++i) { + PrimExpr new_dim = array_ref[i]; + const auto* int_dim = new_dim.as(); + // We expect any symbolic not to signal the intent of -1, and therefore do no check for + // symbolic value here. + if (int_dim == nullptr || int_dim->value > 0) { + new_shape_prod = new_shape_prod * new_dim; + } + } + + // Assign appropriate value to -1 dimension. + if (dim_to_infer != -1) { + arith::Analyzer analyzer; + PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); + array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); + } + return ShapeExpr(array_ref); +} + +Expr reshape(Expr x, ObjectRef shape) { + Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); + static const Op& op = Op::Get("relax.reshape"); + return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); + +StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Reshape op should take 2 arguments"); + } + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Reshape requires the input data to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (new_shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Reshape requires the input new shape to be Shape. 
However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + Optional> old_shape_values; + if (data_sinfo->shape.defined()) { + const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); + ICHECK_NOTNULL(old_shape_sinfo); + old_shape_values = old_shape_sinfo->values; + } + + if (new_shape_sinfo->values.defined() && old_shape_values.defined()) { + PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); + PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); + if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Reshape expects the new shape to be convertible from the old shape. " + "However, the old shape is " + << data_sinfo->shape << ", with product " << old_shape_prod + << ", while the new shape is " << call->args[1] << ", with product " + << new_shape_prod); + } + } + Expr target_shape = call->args[1]; + // If shape values are defined, use them + if (target_shape->IsInstance() && new_shape_sinfo->values.defined()) { + return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype); + } + return TensorStructInfo(target_shape, data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.reshape") + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The input new shape.") + .set_attr("FInferStructInfo", InferStructInfoReshape) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.split */ +TVM_REGISTER_NODE_TYPE(SplitAttrs); + +Expr split(Expr x, ObjectRef indices_or_sections, int axis) { + ObjectPtr attrs = make_object(); + if (const auto* indices = indices_or_sections.as()) { + for (int i = 0; i < static_cast(indices->size()); ++i) { + const auto* idx = indices->at(i).as(); + CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. " + "However, the given indices " + << indices_or_sections << " contains some non-integer."; + } + indices_or_sections = ConvertIntImmToInt64(GetRef>(indices)); + } else if (const auto* n_section = indices_or_sections.as()) { + CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " + "positive integer. However, the given number of sections is " + << n_section->value; + indices_or_sections = IntImm(DataType::Int(64), n_section->value); + } else { + LOG(FATAL) << "Split op expects the input indices_or_sections to be either an Array of " + "PrimExpr or an integer. However, the given one is " + << indices_or_sections->GetTypeKey(); + } + attrs->indices_or_sections = indices_or_sections; + attrs->axis = axis; + + static const Op& op = Op::Get("relax.split"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); + +StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + const auto* data_shape = data_sinfo->shape.as(); + int axis = + data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + if (const auto* p_indices = attrs->indices_or_sections.as()) { + // When there is not index, return the input tensor's struct info. + if (p_indices->size() == 0) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. 
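+ // When the shape is known, the inference below behaves as follows: splitting an axis of
+ // length 10 at indices [3, 7] yields pieces covering [0, 3), [3, 7) and [7, 10), i.e. of
+ // lengths 3, 4 and 3. Indices are clamped to [0, axis length], so out-of-range indices
+ // simply produce empty pieces.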
+ if (data_shape == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + ICHECK_NE(axis, -1); + const auto* axis_length = data_shape->values[axis].as(); + // Fall back to unknown shape when the input tensor shape at the given axis is symbolic. + if (axis_length == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + // Only do output shape inference when all the indices and the total length are integers. + Array indices = GetRef>(p_indices); + IntImm zero(DataType::Int(64), /*value=*/0); + indices.insert(indices.begin(), zero); + indices.insert(indices.end(), Downcast(data_shape->values[axis])); + + std::vector output_sinfo; + output_sinfo.reserve(indices.size() - 1); + for (int i = 0; i + 1 < static_cast(indices.size()); ++i) { + PrimExpr l = tvm::max(zero, indices[i]); + PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]); + + Array shape = data_shape->values; + shape.Set(axis, tvm::max(zero, r - l)); + output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + } + return TupleStructInfo(output_sinfo); + } else if (const auto* p_n_section = attrs->indices_or_sections.as()) { + ICHECK_GT(p_n_section->value, 0); + int n_section = p_n_section->value; + // When the number of section is one, return the input tensor's struct info. + if (n_section == 1) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. + if (data_shape == nullptr) { + return TupleStructInfo( + Array(n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + ICHECK_NE(axis, -1); + PrimExpr split_len = ceildiv(data_shape->values[axis], n_section); + + // Construct struct info for tensors except the last one. + Array shape = data_shape->values; + shape.Set(axis, split_len); + std::vector output_sinfo(n_section - 1, + TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + + // Construct struct info for the last tensor. 
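+  // (Editorial note, not in the original source: with axis length L and n sections, the first
+  //  n - 1 sections get length ceildiv(L, n) and the last gets the remainder
+  //  L - ceildiv(L, n) * (n - 1); e.g. L = 10, n = 3 yields section lengths 4, 4, 2.)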
+ shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1)); + output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + return TupleStructInfo(output_sinfo); + } + ICHECK(false) << "Cannot reach here."; + throw; +} + +InferLayoutOutput InferLayoutSplit(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + + LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); + StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule())); + const auto* out_tuple = out_sinfo.as(); + ICHECK(out_tuple != nullptr) << "Invalid Call"; + NLayout tuple_layouts(Array(out_tuple->fields.size(), existing_layout)); + return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.split") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoSplit) + .set_attr("FRelaxInferLayout", InferLayoutSplit) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +/* relax.squeeze */ +TVM_REGISTER_NODE_TYPE(SqueezeAttrs); + +Expr squeeze(Expr x, Optional> axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.squeeze"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); + +StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axis.defined() && attrs->axis.value().empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + Optional> shape_value; + if (data_sinfo->shape.defined()) { + shape_value = Downcast(data_sinfo->shape.value()->struct_info_)->values; + } + + std::vector axis_removal_mask; + axis_removal_mask.resize(data_sinfo->ndim, /*value=*/false); + + if (attrs->axis.defined()) { + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + + if (!shape_value.defined()) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size()); + } + for (int i = 0; i < static_cast(axes.size()); ++i) { + // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1. + // When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic + const auto* int_len = shape_value.value()[axes[i]].as(); + if (int_len != nullptr && int_len->value != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Squeeze expects the input tensor shape values at the given axis " + "positions to be all 1. However, the tensor shape at axis " + << axes[i] << " is " << shape_value.value()[axes[i]] + << " which is not 1. 
If it is symbolic, please use MatchCast to cast it " + "to 1 before doing Squeeze."); + } + axis_removal_mask[axes[i]] = true; + } + } else { + // When `axis` is not defined, squeeze all unit-length dimensions. + // Note: This is a less well-defined path in Array API standard's squeeze + // (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html). + // Consider discourage usage later. + if (!shape_value.defined()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + for (int i = 0; i < data_sinfo->ndim; ++i) { + // Whenever a dimension length is symbolic, fall back to unknown ndim. + const auto* int_len = shape_value.value()[i].as(); + if (int_len == nullptr) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + if (int_len->value == 1) { + axis_removal_mask[i] = true; + } + } + } + + std::vector output_shape; + output_shape.reserve(data_sinfo->ndim - axis_removal_mask.size()); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (!axis_removal_mask[i]) { + output_shape.push_back(shape_value.value()[i]); + } + } + + if (data_sinfo->shape.value()->IsInstance()) { + if (static_cast(output_shape.size()) == data_sinfo->ndim) { + return data_sinfo; + } else if (attrs->axis.defined()) { + return TensorStructInfo(data_sinfo->dtype, output_shape.size()); + } else { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + } else { + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); + } +} + +InferLayoutOutput InferLayoutSqueeze(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now"; + int ndim = tensor_sinfo->ndim; + const auto* shape = tensor_sinfo->shape.as(); + ICHECK(shape != nullptr) << "Only support static shape for now"; + + Array axis; + if (attrs->axis.defined()) { + axis = attrs->axis.value(); + } else { + axis.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + if (tir::is_one(shape->values[i])) { + axis.push_back(Integer(i)); + } + } + } + + std::string axis_str(ndim, '0'); + for (const auto& iter : axis) { + axis_str[iter->value] = '1'; + } + for (int i = 0, j = 0; i < ndim; ++i) { + if (axis_str[i] != '1') { + axis_str[i] = 'A' + j++; + } + } + + LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); + Array new_axis; + for (size_t i = 0; i < new_axis_str.size(); ++i) { + if (new_axis_str.at(i) == '1') { + new_axis.push_back(Integer(i)); + } + } + std::string output_layout = new_axis_str; + output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), + output_layout.end()); + + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = new_axis; + return InferLayoutOutput({existing_layout}, {LayoutDecision(Layout(output_layout))}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.squeeze") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoSqueeze) + .set_attr("FRelaxInferLayout", InferLayoutSqueeze) + 
.set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, + const Array& data_shape, const Array& target_shape) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + + int data_ndim = data_shape.size(); + int target_ndim = target_shape.size(); + + int data_ax = data_ndim - 1; + int target_ax = target_ndim - 1; + for (; data_ax >= 0; --data_ax) { + if (target_ax < 0) { + continue; + } + const PrimExpr& dim0 = data_shape[data_ax]; + const PrimExpr& dim1 = target_shape[target_ax]; + const auto* int_dim0 = dim0.as(); + const auto* int_dim1 = dim1.as(); + + if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr && int_dim1->value == 1)) { + --target_ax; + } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", the data shape at dim " << data_ax << " is " + << dim0 << " and the target shape at dim " << target_ax << " is " << dim1 + << ", which do not match the rule of collapse sum."); + } else { + // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit + // this requirement to reduce the workload of importers and better support dynamic shapes. + ctx->ReportFatal(Diagnostic::Error(call) + << call->op + << " fails to match the axes because of unknown dim or symbolic" + " shape. In this position the dim of data shape is " + << dim0 << " while the dim of target shape is " << dim1 + << ". If it is symbolic, consider use MatchCast first."); + } + } +} + +/* relax.collapse_sum_like */ +Expr collapse_sum_like(Expr data, Expr collapse_target) { + static const Op& op = Op::Get("relax.collapse_sum_like"); + return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); + +StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo collapse_target_sinfo = input_sinfo[1]; + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> collapse_target_shape_value; + if (collapse_target_sinfo->shape.defined()) { + collapse_target_shape_value = + GetStructInfoAs(collapse_target_sinfo->shape.value())->values; + } + + if (data_shape_value.defined() && collapse_target_shape_value.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value()); + } + + if (collapse_target_sinfo->shape.defined()) { + return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim); + } +} + +TVM_REGISTER_OP("relax.collapse_sum_like") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_target", "Tensor", + "The tensor whose shape is the shape to collapse to.") + .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike); + +/* relax.collapse_sum_to */ +Expr collapse_sum_to(Expr data, Expr shape) { + static const Op& op = Op::Get("relax.collapse_sum_to"); + return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); + +StructInfo 
InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "CollapseSumTo should have 2 arguments"); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* shape_sinfo = GetStructInfoAs(call->args[1]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "CollapseSumTo requires the input data to be a Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "CollapseSumTo requires the input shape to be a Shape. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + + if (data_shape_value.defined() && shape_sinfo->values.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value()); + } + + return TensorStructInfo(/*shape=*/call->args[1], output_dtype); +} + +TVM_REGISTER_OP("relax.collapse_sum_to") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The shape to collapse to.") + .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo); + +/* relax.repeat */ +TVM_REGISTER_NODE_TYPE(RepeatAttrs); + +Expr repeat(Expr data, int repeats, Optional axis) { + auto attrs = make_object(); + attrs->repeats = std::move(repeats); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.repeat"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); + +StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + const auto* data_shape = data_sinfo->shape.as(); + + if (attrs->axis.defined() && !data_sinfo->IsUnknownNdim()) { + int axis = attrs->axis.value()->value; + int ndim = data_sinfo->ndim; + if (axis < -ndim || axis >= ndim) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Repeat requires the input axis belongs range " + "[-data.struct_info.ndim, data.struct_info.ndim - 1]. 
However, the input axis is " + << axis << ", while ndim is " << ndim); + } + } + + if (data_shape == nullptr) { + if (attrs->axis.defined()) { + if (analyzer->CanProveEqual(attrs->repeats, 1)) { + // the shape does not changes + return data_sinfo; + } else { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + } else { + return TensorStructInfo(data_sinfo->dtype, 1); + } + } + + if (!attrs->axis.defined()) { + PrimExpr new_shape = + analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats); + return TensorStructInfo(ShapeExpr(Array({new_shape})), data_sinfo->dtype); + } + + int axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value); + auto shape_array = data_shape->values; + shape_array.Set(axis, analyzer->Simplify(shape_array[axis] * attrs->repeats)); + return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype); +} + +// TODO(relax-team): implement FRelaxInferLayout for repeat +TVM_REGISTER_OP("relax.repeat") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoRepeat); + +/* relax.tile */ +TVM_REGISTER_NODE_TYPE(TileAttrs); + +Expr tile(Expr data, Array repeats) { + auto attrs = make_object(); + attrs->repeats = std::move(repeats); + + static const Op& op = Op::Get("relax.tile"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); + +StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + const auto* data_shape = data_sinfo->shape.as(); + int l = attrs->repeats.size(); + int ndim = data_sinfo->ndim; + + if (data_shape == nullptr) { + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + if (l > ndim) { + return TensorStructInfo(data_sinfo->dtype, l); + } else { + for (auto i : attrs->repeats) { + if (!analyzer->CanProveEqual(i, 1)) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + } + // if control reaches here, the shape should not be changed + return data_sinfo; + } + } + + int out_ndim = std::max(l, ndim); + int l_delta = out_ndim - l; + int ndim_delta = out_ndim - ndim; + Array out_shape; + for (int i = 0; i < out_ndim; ++i) { + if (i < l_delta) { + out_shape.push_back(data_shape->values[i - ndim_delta]); + } else if (i < ndim_delta) { + out_shape.push_back(attrs->repeats[i - l_delta]); + } else { + out_shape.push_back( + analyzer->Simplify(data_shape->values[i - ndim_delta] * attrs->repeats[i - l_delta])); + } + } + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +// TODO(relax-team): implement FRelaxInferLayout for tile +TVM_REGISTER_OP("relax.tile") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTile); + +/* relax.cumsum */ +TVM_REGISTER_NODE_TYPE(CumsumAttrs); + +Expr cumsum(Expr data, Optional axis, DataType dtype) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->dtype = std::move(dtype); + + static const Op& op = Op::Get("relax.cumsum"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); + +StructInfo InferStructInfoCumsum(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo 
data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + DataType out_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype; + + if (!attrs->axis.defined()) { + // flattened + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(out_type, data_sinfo->ndim); + } else { + PrimExpr flattened_d = 1; + for (const auto v : data_shape->values) { + flattened_d *= v; + } + return TensorStructInfo(ShapeExpr(Array({flattened_d})), out_type); + } + } + + if (data_sinfo->shape.defined()) { + return TensorStructInfo(data_sinfo->shape.value(), out_type); + } else { + return TensorStructInfo(out_type, data_sinfo->ndim); + } +} + +TVM_REGISTER_OP("relax.cumsum") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoCumsum); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h new file mode 100644 index 000000000000..592d8b13479e --- /dev/null +++ b/src/relax/op/tensor/manipulate.h @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manipulate.h + * \brief The functions to make Relax tensor manipulation operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ +#define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Broadcasts a tensor to a specified shape. */ +Expr broadcast_to(Expr x, Expr shape); + +/*! + * \brief Concatenate the input tensors along the given axis. + * \param tensors An Expr in Tuple type, containing the tensors to be concatenated, + * or a list of tensors + * \param axis The axis along which the tensors are concatenated. + * If it is `NullOpt`, the input tensor is required to be flattened before concatenation. + * \return The concatenated tensor. + */ +Expr concat(Expr tensors, Optional axis); + +/*! + * \brief Insert new axes at the positions given by `axis`. + * \param x The input data to the operator. + * \param axis The axes at which the input array are expanded. + * \return The transformed result. + */ +Expr expand_dims(Expr x, Array axis); + +/*! + * \brief Flatten all the tensor dimensions into one. + * \param x The input data to the operator. + * \return The flattened result. + */ +Expr flatten(Expr x); + +/*! + * \brief Transform layout of a tensor. + * \param x The input data to the operator. + * \param index_map The transformation to apply. + * \param pad_value The value used for padding if the transformation results in implicit padding. If + * not specified, any value can be used. + * \return The transformed result. 
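+ * \note Illustrative example (not part of the original documentation): an NCHW to NCHW4c
+ *       transformation can be expressed with the index map
+ *       `lambda n, c, h, w: (n, c // 4, h, w, c % 4)`, optionally together with a `pad_value`
+ *       when the channel extent is not a multiple of 4.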
+ */ +Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value); + +/*! + * \brief Permutes the dimensions of an array. + * \param x The input data to the operator. + * \param axes The target axes order, reverse order if not specified. + * \return The transposed result. + */ +Expr permute_dims(Expr x, Optional> axes); + +/*! + * \brief Reshape the input array, supporting `-1` inference in the new + * shape when the new shape is given as an Array of PrimExpr. + * \param x The input data to the operator. + * \param shape The new shape. Should be compatible with the original shape. + * It is required to be either an Array of PrimExpr, or a Shape in Relax + * \return The reshaped result. + */ +Expr reshape(Expr x, ObjectRef shape); + +/*! + * \brief Split input tensor along axis by sections or indices. + * - If indices_or_sections is an integer, the input will be divided equally + * along given axis (if possible). Last section will be smaller if the tensor + * size along the given dimension is not divisible by the integer. + * - If indices_or_sections is a tuple of mixture of int or PrimExpr, + * the entries indicate the indices where along axis the array is split. + * \param x The tensor to be split. + * \param indices_or_sections Indices or sections to split into. + * It is required to be an Array of PrimExpr or an integer. + * \param axis The axis over which to split. + * \return The computed result. + */ +Expr split(Expr x, ObjectRef indices_or_sections, int axis); + +/*! + * \brief Squeeze axes in the array. + * \param x The input data to the operator. + * \param axis The set of axes to remove. + * If it is `NullOpt`, remove all axis of dimensions 1. + * If any specified axis has dimension that does not equal 1, it is an error. + * \return The squeezed result. + */ +Expr squeeze(Expr x, Optional> axis); + +/*! + * \brief Return a summation of data to the shape of collapse_target. + * For details, please see the operator `relax.collapse_sum_to`. + * \param data The input tensor. + * \param collapse_target The tensor whose shape is the shape to collapse to. + * \return The result tensor after summation. + */ +Expr collapse_sum_like(Expr data, Expr collapse_target); + +/*! + * \brief Return a summation of data to the given shape. + * collapse_sum_to is intended as the backward operator of broadcast_to and + * other broadcast operators in the automatic differentiation process. + * We expect that data is the result of broadcasting some tensor of the given shape in some + * broadcast operation. Thus the given shape and data.shape must follow broadcast rules. + * \param data The input tensor. + * \param shape The shape to collapse to. + * \return The result tensor of the given shape after summation. + */ +Expr collapse_sum_to(Expr data, Expr shape); + +/*! + * \brief Repeats elements of an array. + * \param data The input tensor. + * \param repeats The number of repetitions. + * \param axis The axis along which to repeat values. The negative numbers are interpreted counting + * from the backward. By default, use the flattened input array, and return a flat output array. + * \return The computed result. + */ +Expr repeat(Expr data, int repeats, Optional axis = NullOpt); + +/*! + * \brief Construct an array by repeating data the number of times given by reps. + * + * If reps has length l, and data has dimension d, the result will have dimension of max(l, d). + * + * If d < l, data is promoted to be l-dimensional by prepending new axes. 
So a shape (3,) Tensor is + * promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not + * the desired behavior, promote data to d-dimensions manually before calling this function. + * + * If d > l, reps is promoted to length d by pre-pending 1's to it. Thus for a data of shape + * (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). + * \param data The input tensor. + * \param repeats The number of repetitions of data along each axis. + * \return The computed result. + */ +Expr tile(Expr data, Array repeats); + +/*! + * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along + * a given axis. + * \param data The input tensor. + * \param axis Axis along which the cumulative sum is computed. The default (None) is to compute + * the cumsum over the flattened array. + * \param dtype Type of the returned array and of the accumulator in which the elements are summed. + * If dtype is not specified, it defaults to the dtype of data. + * \return The computed result. + */ +Expr cumsum(Expr data, Optional axis = NullOpt, DataType dtype = DataType::Void()); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_MANIPULATE_H_ diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc new file mode 100644 index 000000000000..71f37c743ff2 --- /dev/null +++ b/src/relax/op/tensor/search.cc @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.cc + * \brief Searching operators. + */ + +#include "search.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.where */ +Expr where(Expr condition, Expr x1, Expr x2) { + static const Op& op = Op::Get("relax.where"); + return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); + +StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo cond_sinfo = input_sinfo[0]; + TensorStructInfo x1_sinfo = input_sinfo[1]; + TensorStructInfo x2_sinfo = input_sinfo[2]; + + if (!cond_sinfo->dtype.is_bool()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Where requires the input condition tensor to have boolean dtype. 
However, " + "the given condition dtype is " + << cond_sinfo->dtype); + } + DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo); + + int output_ndim; + if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownNDim; + } else { + output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim)); + } + + const auto* cond_shape = cond_sinfo->shape.as(); + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + if (cond_shape && x1_shape && x2_shape) { + // Step 1. Compute the broadcasted shape of x1's and x2's + Optional> broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + // Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape. + broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value()); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); + return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype); + } else if (cond_sinfo->shape.defined() && // + x1_sinfo->shape.defined() && // + x2_sinfo->shape.defined() && // + cond_sinfo->shape.same_as(x1_sinfo->shape) && // + cond_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorStructInfo(cond_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, output_ndim); + } +} + +TVM_REGISTER_OP("relax.where") + .set_num_inputs(3) + .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoWhere); + +/* relax.argmax & relax.argmin */ +TVM_REGISTER_NODE_TYPE(ArgmaxArgminAttrs); + +StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + int axis = -1; + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value); + } + + int out_ndim; + if (attrs->keepdims) { + out_ndim = data_sinfo->ndim; + } else if (!attrs->axis.defined()) { + out_ndim = 0; + } else if (data_sinfo->IsUnknownNdim()) { + out_ndim = kUnknownNDim; + } else { + out_ndim = data_sinfo->ndim - 1; + ICHECK_GE(out_ndim, 0); + } + + DataType out_dtype = DataType::Int(64); + // The inference rule for reduction operator output shapes: + // - axes is None, keepdims is false -> return the zero-rank shape; + // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every + // value is 1. + // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. + // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the + // input axes + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { + return TensorStructInfo(ShapeExpr(Array(out_ndim, IntImm(out_dtype, /*value=*/1))), + out_dtype); + } else { + return out_ndim == 0 ? 
TensorStructInfo(ShapeExpr(Array()), out_dtype) + : TensorStructInfo(out_dtype, out_ndim); + } + } + + if (data_sinfo->ndim > 0) { + out_dtype = data_shape->values[0]->dtype; + } + + Array out_shape; + out_shape.reserve(out_ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (attrs->axis.defined() && i != axis) { + out_shape.push_back(data_shape->values[i]); + } else if (attrs->keepdims) { + out_shape.push_back(IntImm(out_dtype, /*value=*/1)); + } + } + ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +#define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ + Expr OpName(Expr x, Optional axis, bool keepdims) { \ + ObjectPtr attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = std::move(keepdims); \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {std::move(x)}, Attrs(attrs)); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ + .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin); + +RELAX_REGISTER_ARGMAX_ARGMIN_OP(argmax); +RELAX_REGISTER_ARGMAX_ARGMIN_OP(argmin); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h new file mode 100644 index 000000000000..ad9f8b09ecca --- /dev/null +++ b/src/relax/op/tensor/search.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.h + * \brief The functions to make Relax searching operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SEARCH_H_ +#define TVM_RELAX_OP_TENSOR_SEARCH_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Selecting elements from either the input tensors depending on the value of the + * condition. + */ +Expr where(Expr condition, Expr x1, Expr x2); + +/*! \brief Computes the argmax of tensor elements over given axis. */ +Expr argmax(Expr x, Optional axis, bool keepdims); + +/*! \brief Computes the argmin of tensor elements over given axis. */ +Expr argmin(Expr x, Optional axis, bool keepdims); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SEARCH_H_ diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc new file mode 100644 index 000000000000..8df0813ed2b5 --- /dev/null +++ b/src/relax/op/tensor/set.cc @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.cc + * \brief Relax set operators. + */ + +#include "set.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.unique */ + +Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, + PrimValue return_counts, Optional axis) { + static const Op& op = Op::Get("relax.unique"); + Call call; + if (!axis) { + call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts}); + } else { + PrimValue pv_axis = axis.value(); + call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts, pv_axis}); + } + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); + +StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); + PrimValue axis, return_index, return_inverse, return_counts; + if (call->args.size() == 6) { + if (auto* prim_value_node = call->args[5].as()) { + axis = GetRef(prim_value_node); + } + } + if (!data_sinfo->IsUnknownNdim() && axis.defined()) { + // Normalize the axis for sanity check purpose. + if (const auto* axis_int = axis->value.as()) { + NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value); + } + } + ICHECK(call->args[2]->IsInstance()); + ICHECK(call->args[3]->IsInstance()); + ICHECK(call->args[4]->IsInstance()); + + return_index = Downcast(call->args[2]); + return_inverse = Downcast(call->args[3]); + return_counts = Downcast(call->args[4]); + + auto f_convert_to_int64 = [](const PrimExpr& value) { + CHECK(value->IsInstance()) + << value << " expects to be IntImm, but gets " << value->GetTypeKey(); + const auto* val_node = value.as(); + auto val_imm = GetRef(val_node); + return val_imm->value; + }; + + int64_t n_int_return = f_convert_to_int64(return_index->value) + + f_convert_to_int64(return_inverse->value) + + f_convert_to_int64(return_counts->value); + + std::vector output_sinfo; + output_sinfo.reserve(1 + n_int_return); + + // unique values + if (data_sinfo->ndim == 0) { + output_sinfo.push_back( + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype)); + } else if (axis.defined()) { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)); + } else { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1)); + } + + // index, reverse and counts + TensorStructInfo int_return{nullptr}; + if (data_sinfo->ndim == 0) { + int_return = + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64)); + } else { + int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1); + } + for (int i = 0; i < n_int_return; ++i) { + output_sinfo.push_back(int_return); + } + + if (output_sinfo.size() == 1) { + return output_sinfo[0]; + } else { + return TupleStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.unique") + .set_num_inputs(6) + .add_argument("x", "Tensor", 
"The input tensor") + .add_argument( + "sorted", "Tensor", + "Whether to sort the unique elements in ascending order before returning as output.") + .add_argument( + "return_index", "Tensor", + "Whether to return an additional tensor with indices for where elements in the unique " + "tensor come from the original input.") + .add_argument("return_inverse", "Tensor", + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list.") + .add_argument("return_counts", "Tensor", + "Whether to return an additional tensor with counts of each unique elements") + .add_argument( + "axis", "Tensor", + "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " + "are returned.") + .set_attr("FInferStructInfo", InferStructInfoUnique) + .set_attr("FCallPacked", "relax.run.unique"); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h new file mode 100644 index 000000000000..a5c7ee85bfb2 --- /dev/null +++ b/src/relax/op/tensor/set.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. Sex The NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. Sex The License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.h + * \brief The functions to make Relax set operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SET_H_ +#define TVM_RELAX_OP_TENSOR_SET_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, + PrimValue return_counts, Optional axis); +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SET_H_ diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc new file mode 100644 index 000000000000..57a31dbb448a --- /dev/null +++ b/src/relax/op/tensor/statistical.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.cc + * \brief Statistical operators. 
+ */ + +#include "statistical.h" + +#include +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + std::vector axes; + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + } + + int out_ndim; + if (attrs->keepdims) { + out_ndim = data_sinfo->ndim; + } else if (!attrs->axis.defined()) { + out_ndim = 0; + } else if (data_sinfo->IsUnknownNdim()) { + out_ndim = kUnknownNDim; + } else { + out_ndim = data_sinfo->ndim - axes.size(); + ICHECK_GE(out_ndim, 0); + } + + // The inference rule for reduction operator output shapes: + // - axes is None, keepdims is false -> return the zero-rank shape; + // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every + // value is 1. + // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. + // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the + // input axes + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { + return TensorStructInfo( + ShapeExpr(Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), + data_sinfo->dtype); + } else { + return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array()), data_sinfo->dtype) + : TensorStructInfo(data_sinfo->dtype, out_ndim); + } + } + + Array out_shape; + out_shape.reserve(out_ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { + out_shape.push_back(data_shape->values[i]); + } else if (attrs->keepdims) { + out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + } + } + ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +InferLayoutOutput InferLayoutStatistical(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + int ndim = tensor_sinfo->ndim; + + Array axis; + if (attrs->axis.defined()) { + axis = attrs->axis.value(); + } else { + axis.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + axis.push_back(Integer(i)); + } + } + + std::string axis_str(ndim, '0'); + for (const auto& iter : axis) { + axis_str[(iter->value + ndim) % ndim] = '1'; + } + for (int i = 0, j = 0; i < ndim; ++i) { + if (axis_str[i] != '1') { + axis_str[i] = 'A' + j++; + } + } + + LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, call->args[0]); + String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), exisiting_layout->layout); + Array new_axis; + for (size_t i = 0; i < new_axis_str.size(); ++i) { + if (new_axis_str.at(i) == '1') { + new_axis.push_back(Integer(i)); + } + } + std::string output_layout = new_axis_str; + output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), + output_layout.end()); + + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = 
new_axis; + return InferLayoutOutput({exisiting_layout}, + {attrs->keepdims ? exisiting_layout : Layout(output_layout)}, + Attrs(new_attrs)); +} + +TVM_REGISTER_NODE_TYPE(StatisticalAttrs); + +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(prod); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(std); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(sum); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(variance); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h new file mode 100644 index 000000000000..0adeb822591d --- /dev/null +++ b/src/relax/op/tensor/statistical.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.h + * \brief The functions to make Relax statistical operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_STATISTICAL_H_ +#define TVM_RELAX_OP_TENSOR_STATISTICAL_H_ + +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \ + Expr OpName(Expr x, Optional> axis, bool keepdims) { \ + ObjectPtr attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ + .set_attr("FInferStructInfo", InferStructInfoStatistical) \ + .set_attr("FRelaxInferLayout", InferLayoutStatistical) + +/*! + * \brief Computes the maximum value of tensor elements over given axes. + * \param x The input data tensor + * \param axis Axis or axes along which a max is performed. Being `NullOpt` means to max all the + * elements of the input tensor + * \param keepdims If this is set to True, the axes which are reduced are left in the result as + * dimensions with size one. With this option, the result will broadcast correctly against the + * input tensor. + * \return The result after reduction. + */ +Expr max(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the mean of tensor elements over given axes. 
*/ +Expr mean(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the min of tensor elements over given axes. */ +Expr min(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the product of tensor elements over given axes. */ +Expr prod(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the standard deviation of tensor elements over given axes. */ +Expr std(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the sum of tensor elements over given axes. */ +Expr sum(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the variance of tensor elements over given axes. */ +Expr variance(Expr x, Optional> axis, bool keepdims); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_STATISTICAL_H_ diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc new file mode 100644 index 000000000000..940192bd8e45 --- /dev/null +++ b/src/relax/op/tensor/ternary.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.cc + * \brief ternary operators. 
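+ *
+ * \note Illustrative example (not part of the original comment): `ewise_fma(x1, x2, x3)` computes
+ *       `x1 * x2 + x3` element-wise; the struct-info inference below requires the three operands
+ *       to agree on ndim and dtype, and on shape when it can be proven.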
+ */ + +#include "ternary.h" + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo t1 = input_sinfo[0]; + TensorStructInfo t2 = input_sinfo[1]; + TensorStructInfo t3 = input_sinfo[2]; + + int ndim = kUnknownNDim; + if (!t1->IsUnknownNdim()) { + ndim = t1->ndim; + } + if (!t2->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t2->ndim; + } else if (t2->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + if (!t3->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t3->ndim; + } else if (t3->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + + DataType output_dtype; + if (t1->IsUnknownDtype() || t2->IsUnknownDtype() || t3->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t1->dtype != t2->dtype || t2->dtype != t3->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype + << " must be equal for EwiseFMA"); + } else { + output_dtype = t1->dtype; + } + + auto* s1 = t1->shape.as(); + auto* s2 = t2->shape.as(); + auto* s3 = t3->shape.as(); + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + if (s1 && s2 && s3) { + Array output_shape; + for (int i = 0; i < ndim; ++i) { + PrimExpr dim1 = s1->values[i]; + PrimExpr dim2 = s2->values[i]; + PrimExpr dim3 = s3->values[i]; + if (analyzer->CanProveEqual(dim1, dim2) && analyzer->CanProveEqual(dim2, dim3)) { + output_shape.push_back(dim1); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same shape"); + } + } + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) { + return TensorStructInfo(t1->shape.value(), output_dtype); + } + + return TensorStructInfo(output_dtype, ndim); +} + +InferLayoutOutput InferLayoutEwiseFMA(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]); + LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[1]); + LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[2]); + LayoutDecision layout = layout0; + if (NLayoutEqual()(layout1, layout2)) { + layout = layout1; + } + return InferLayoutOutput({layout, layout, layout}, {layout}, Attrs(call->attrs)); +} + +TVM_REGISTER_OP("relax.ewise_fma") + .set_num_inputs(3) + .add_argument("x1", "Tensor", "The left hand operand of the multiplication") + .add_argument("x2", "Tensor", "The right hand operand of the multiplication") + .add_argument("x3", "Tensor", "The operand of the addition") + .set_attr("FInferStructInfo", InferStructInfoEwiseFMA) + .set_attr("FRelaxInferLayout", InferLayoutEwiseFMA) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + +Expr ewise_fma(Expr x1, Expr x2, Expr x3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {x1, x2, x3}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h new file 
mode 100644 index 000000000000..ba22c56d9efd --- /dev/null +++ b/src/relax/op/tensor/ternary.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.h + * \brief The functions to make Relax ternary operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_TERNARY_H_ +#define TVM_RELAX_OP_TENSOR_TERNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Elementwise fused multiply-add operator + * Returns elementwise result of `x1 * x2 + x3` + * \param x1 The left hand operand of the multiplication + * \param x2 The right hand operand of the multiplication + * \param x3 The operand of the addition + * \return The computed result. + */ +Expr ewise_fma(Expr x1, Expr x2, Expr x3); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_TERNARY_H_ diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc new file mode 100644 index 000000000000..f1117c1826c5 --- /dev/null +++ b/src/relax/op/tensor/unary.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.cc + * \brief Relax unary arithmetic operators. 
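+ *
+ * \note Illustrative usage (an assumption, not part of the original comment): since relax.clip
+ *       takes its bounds as PrimValue expressions, a ReLU6-style clamp could be built as
+ *       `clip(x, PrimValue::Int64(0), PrimValue::Int64(6))`, assuming the PrimValue::Int64 helper.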
+ */ + +#include "unary.h" + +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return DataType::Bool(); }); +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(abs, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atanh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(ceil, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(exp, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(floor, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(log, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(negative, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(round, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sigmoid, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sign, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true); + +// relax.clip +TVM_REGISTER_OP("relax.clip") + .set_num_inputs(3) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") + .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") + .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>); + +Expr clip(Expr x, Expr min, Expr max) { + CHECK(min->IsInstance()) + << "The argument `min` of relax.clip is expected to be a PrimValue, but got" + << min->GetTypeKey(); + CHECK(max->IsInstance()) + << "The argument `max` of relax.clip is expected to be a PrimValue, but got" + << max->GetTypeKey(); + static const Op& op = Op::Get("relax.clip"); + return Call(op, {std::move(x), std::move(min), std::move(max)}); +} + +TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); + +/***************** Check operators *****************/ + +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isinf); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isnan); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h new file mode 100644 index 000000000000..8f6404c5d9ed --- /dev/null +++ b/src/relax/op/tensor/unary.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.h + * \brief The functions to make Relax unary arithmetic operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_UNARY_H_ +#define TVM_RELAX_OP_TENSOR_UNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which constructs the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + * (Only for unary arith operators since all check operators don't require float dtype.) + */ +#define RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName) \ + RELAX_UNARY_OP_INTERFACE(OpName, #OpName); \ + RELAX_REGISTER_UNARY_OP(#OpName) + +#define RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(OpName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \ + "FInferStructInfo", InferStructInfoUnaryArith<RequireFloatDtype>) + +#define RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \ + "FInferStructInfo", InferStructInfoUnaryCheck) // require_float_dtype=false for check op + +/***************** Arithmetic operators *****************/ + +/*! + * \brief Compute element-wise absolute value of the input data. + * \param x The input data. + * \return The computed result. + */ +Expr abs(Expr x); + +/*! \brief Compute element-wise arc cos of the input data. */ +Expr acos(Expr x); + +/*! \brief Compute element-wise arc cosh of the input data. */ +Expr acosh(Expr x); + +/*! \brief Compute element-wise arc sin of the input data. */ +Expr asin(Expr x); + +/*! \brief Compute element-wise arc sinh of the input data. */ +Expr asinh(Expr x); + +/*! \brief Compute element-wise arc tan of the input data. */ +Expr atan(Expr x); + +/*! \brief Compute element-wise arc tanh of the input data. */ +Expr atanh(Expr x); + +/*! \brief Take ceil of input data. */ +Expr ceil(Expr x); + +/*! \brief Compute element-wise cos of the input data. */ +Expr cos(Expr x); + +/*! \brief Compute element-wise cosh of the input data. */ +Expr cosh(Expr x); + +/*! \brief Compute element-wise exp of data. */ +Expr exp(Expr x); + +/*! \brief Take floor of input data. */ +Expr floor(Expr x); + +/*! \brief Compute element-wise natural logarithm of data. */ +Expr log(Expr x); + +/*! \brief Compute element-wise negative value of data. */ +Expr negative(Expr x); + +/*! \brief Rounds each element of the input data to nearest integer. */ +Expr round(Expr x); + +/*! \brief Compute element-wise sigmoid of data. */ +Expr sigmoid(Expr x); + +/*! \brief Returns an indication of the sign of a number for each element of the input data. */ +Expr sign(Expr x); + +/*! \brief Compute element-wise sin of data. */ +Expr sin(Expr x); + +/*! \brief Compute element-wise sinh of data.
*/ +Expr sinh(Expr x); + +/*! \brief Compute element-wise square root of data. */ +Expr sqrt(Expr x); + +/*! \brief Squares each element of the input data. */ +Expr square(Expr x); + +/*! \brief Compute element-wise tan of data. */ +Expr tan(Expr x); + +/*! \brief Compute element-wise tanh of data. */ +Expr tanh(Expr x); + +/*! \brief Clips tensor values to a specified min and max. */ +Expr clip(Expr x, Expr min, Expr max); + +/***************** Check operators *****************/ + +/*! \brief Check if input value is finite. */ +Expr isfinite(Expr x); + +/*! \brief Check if input value is infinite. */ +Expr isinf(Expr x); + +/*! \brief Check if input value is Nan. */ +Expr isnan(Expr x); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_UNARY_H_ diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc new file mode 100644 index 000000000000..1fadb86d715c --- /dev/null +++ b/src/relax/transform/alter_op_impl.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/alter_op_impl.cc + * \brief Change the layout of PrimFunc in the graph. It uses the kOperatorName attribute to + * identify PrimFuncs to be replaced. Marks the new PrimFuncs with kFrozenLayout attribute set to + * true. + */ +#include +#include +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { + +using namespace tir; +static constexpr const char* kOperatorName = "operator_name"; + +/*! \brief Construct ranges from shape dimensions */ +static Array ConstructRangeFromShape(const Array& shape) { + return shape.Map([](const PrimExpr& dim) { return Range(tir::make_zero(dim.dtype()), dim); }); +} + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +static Array GetShapeFromTensor(const Expr& expr) { + const auto& tensor_sinfo = Downcast(expr->struct_info_); + return GetShapeFromTensorStructInfo(tensor_sinfo); +} + +static IndexMap DeepCopyIndexMap(const IndexMap& index_map) { + return Downcast(LoadJSON(SaveJSON(index_map))); +} + +/*! \brief Checks if the \p transform is bijective on the shape of \p expr */ +bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { + Array input_shape = GetShapeFromTensor(expr); + Array initial_ranges = ConstructRangeFromShape(input_shape); + auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges); + (void)inverse; // to avoid unused variable warning; + arith::Analyzer analyzer; + if (!analyzer.CanProve(!padding_predicate)) return false; + return true; +} + +/*! 
+ * \brief Replace each call_tir to PrimFunc which matches the kOperatorName attribute with the + * provided replacement PrimFunc and mark it with kFrozenLayout attribute. Insert layout + * transformations on i/o buffers as necessary for correctness. + */ +class AlterOpImplMutator : public ExprMutator { + public: + AlterOpImplMutator(const IRModule& mod, const Map& op_impl_map, + const Map>& op_buffer_transforms_) + : ExprMutator(mod), + mod_(mod), + op_impl_map_(op_impl_map), + op_buffer_transforms__(op_buffer_transforms_) {} + + IRModule Run() { + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + relax::Function update_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(gv, update_func); + } + } + return builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* op) final { + auto call = Downcast(ExprMutator::VisitExpr_(op)); + + // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // this logic should be changed accordingly. + if (!call->op.same_as(call_tir_op_)) return call; + + // Do not do anything for external function + if (call->args[0].as()) return call; + + // Get operator name from callee + ICHECK(call->args[0]->IsInstance()); + const tir::PrimFunc& old_func = + Downcast(mod_->Lookup(Downcast(call->args[0]))); + Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); + + // If the callee does not have kOperatorName attribute or no replacement is requested for + // it, nothing to do here. + if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value()) == 0) return call; + auto op_kind = maybe_op_kind.value(); + + const auto& replacement_func = op_impl_map_[op_kind]; + + Array buffer_transforms; + if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind]; + + ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size()) + << "Either the i/o buffers do not require any transformations or transformations for each " + "buffer is provided."; + ICHECK_EQ(old_func->params.size(), replacement_func->params.size()) + << "Number of parameters of old and replacement PrimFunc must match"; + + GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind); + + auto call_tir_inputs_tuple = GetRef(call->args[1].as()); + Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms); + + ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1"; + StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms); + auto updated_call = builder_->Normalize( + Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo})); + + // Now transform each of the outputs to previous layout. 
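The layout changes that AlterOpImplMutator wraps around the replaced call_tir are described by tir::IndexMap objects, and IsTransformBijective above only admits maps whose non-surjective inverse has a provably false padding predicate. A small illustrative sketch of such a bijective map (a pure NCHW-to-NHWC axis permutation), using the Python IndexMap API purely for exposition:

    import tvm
    from tvm import tir

    # Axis permutation NCHW -> NHWC: bijective, so inverting it introduces no padding
    # and the mutator accepts it as an input/output buffer transform.
    index_map = tir.IndexMap.from_func(lambda n, c, h, w: (n, h, w, c))
    print(index_map.map_shape([1, 64, 56, 56]))  # expected: [1, 56, 56, 64]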
+ return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0]); + } + + Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { + if (const auto* tensor_sinfo = output_sinfo.as()) + return {GetRef(tensor_sinfo)}; + const auto* tuple_sinfo = output_sinfo.as(); + ICHECK(tuple_sinfo); + + Array arr_tensor_sinfo; + arr_tensor_sinfo.reserve(tuple_sinfo->fields.size()); + for (const auto& sinfo : tuple_sinfo->fields) { + const auto* tensor_sinfo = sinfo.as(); + ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; + arr_tensor_sinfo.push_back(GetRef(tensor_sinfo)); + } + return arr_tensor_sinfo; + } + + Expr TransformLayout(const Expr& expr, const IndexMap& index_map) { + ObjectPtr attrs = make_object(); + // We want to avoid two layout_transform ops to share the same index map even if they are + // identical. The scope of vars used in index map initial indices is local to the op. Not doing + // so would confuse the structural equality check. + attrs->index_map = std::move(DeepCopyIndexMap(index_map)); + return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); + } + + Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map, + const TensorStructInfo& old_tensor_sinfo) { + Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); + Array initial_ranges = ConstructRangeFromShape(old_shape); + auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges); + ICHECK(tir::is_zero(padding_predicate)) + << "Only bijective transformations on input/output buffers are supported, but found " + "padding predicate " + << padding_predicate << " on initial range " << initial_ranges; + return TransformLayout(expr, inverse_index_map); + } + + /*! + * \brief Adds the \p replacement_func to the module if it has not already been added before. + * \returns The global var associated with the PrimFunc. + */ + GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, const String& op_kind) { + if (cache_.count(replacement_func) != 0) { + return cache_[replacement_func]; + } + // Retain the operator name attribute on the replacement PrimFunc. This can help any future + // passes that use kOperatorName attribute to identify operator represented by a PrimFunc. + PrimFunc replacement_func_with_frozen_layout = + WithAttr(replacement_func, kOperatorName, op_kind); + + GlobalVar gv_replacement = + builder_->AddFunction(replacement_func_with_frozen_layout, op_kind + "_replacement"); + cache_.Set(replacement_func, gv_replacement); + return gv_replacement; + } + + /*! + * \brief Updates call inputs with layout transformed inputs + */ + Tuple UpdateInputs(const Tuple& inputs, const Array& transforms) { + if (transforms.empty()) return inputs; + + Array updated_inputs; + int index = 0; + for (const auto& input : inputs->fields) { + auto transform = transforms[index++]; + ICHECK(IsTransformBijective(input, transform)) + << "Non bijective transforms on input and output buffers are not supported."; + updated_inputs.push_back(TransformLayout(input, transform)); + } + return Tuple(updated_inputs); + } + + /*! 
\brief Updates output struct info */ + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, + const Array& buffer_transforms) { + if (buffer_transforms.empty()) return out_sinfo; + + if (out_sinfo->IsInstance()) + return UpdateStructInfo(Downcast(out_sinfo), + buffer_transforms[buffer_transforms.size() - 1]); + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size(); + size_t i = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + sinfo_fields.push_back(UpdateStructInfo(Downcast(si), + buffer_transforms[first_output_index + i++])); + } + return TupleStructInfo(sinfo_fields); + } + + /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */ + StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) { + auto shape = GetShapeFromTensorStructInfo(tensor_sinfo); + auto new_shape = transform->MapShape(shape); + return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype); + } + + Expr TransformOutputs(const Expr& expr, const Array& buffer_transforms, + const StructInfo& old_struct_info) { + if (buffer_transforms.empty()) return expr; + + Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); + + size_t num_outputs = old_output_sinfo.size(); + if (num_outputs == 0) return expr; + + size_t first_output_index = buffer_transforms.size() - num_outputs; + // If there is a single output, return the transformed output. + if (num_outputs == 1) { + IndexMap output_map = buffer_transforms[first_output_index]; + return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]); + } + + // In case of more than one output, we would have to get each item of the output tuple, + // transform it and return a tuple of all transformed outputs. + Array transformed_outputs; + for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) { + const auto& output_map = buffer_transforms[i + first_output_index]; + auto output = builder_->Normalize(TupleGetItem(expr, static_cast(i))); + transformed_outputs.push_back( + TransformLayoutInverse(output, output_map, old_output_sinfo[i])); + } + return Tuple(transformed_outputs); + } + + private: + /*! \brief Cache to keep track of the GlobalVar associated with the new PrimFunc added */ + Map cache_; + /*! \brief Input IRModule */ + const IRModule& mod_; + /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */ + const Map& op_impl_map_; + /*! 
\brief Map from kOperatorName attribute to the layout transforms on i/o buffers */ + const Map>& op_buffer_transforms__; + + const Op& call_tir_op_ = Op::Get("relax.call_tir"); + const Op& layout_transform_op_ = Op::Get("relax.layout_transform"); +}; + +namespace transform { + +Pass AlterOpImpl(const Map& op_impl_map, + const Map>& op_buffer_transforms_) { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { + return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run(); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"AlterOpImpl", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc new file mode 100644 index 000000000000..b1c1ed29aff3 --- /dev/null +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/annotate_tir_op_pattern.cc + * \brief Annotate Op Pattern for TIR functions. It is a pass works on TIR PrimFuncs, + * but they are needed for relax fusion. So we put them in the relax namespace. + */ +#include +#include +#include + +namespace tvm { +namespace relax { + +tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) { + if (f->HasNonzeroAttr("op_pattern")) { + return f; + } else { + relay::OpPatternKind kind = AnalyzeOpPatternKind(f); + return WithAttr(std::move(f), "op_pattern", Integer(static_cast(kind))); + } +} + +namespace transform { + +Pass AnnotateTIROpPattern() { + auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) { + return AnnotateOpPattern(std::move(f)); + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc new file mode 100644 index 000000000000..be779e97bcf5 --- /dev/null +++ b/src/relax/transform/attach_global_symbol.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
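For the AnnotateTIROpPattern pass registered above, the effect is easiest to see from Python: any PrimFunc without an existing "op_pattern" attribute gets one, and relax FuseOps later consumes it. A hedged sketch (the te-built PrimFunc and the module key "add_one" are illustrative, not part of this patch):

    import tvm
    from tvm import te, relax

    # A trivially elementwise PrimFunc built through TE.
    A = te.placeholder((8,), name="A", dtype="float32")
    B = te.compute((8,), lambda i: A[i] + 1.0, name="B")
    mod = tvm.IRModule({"add_one": te.create_prim_func([A, B])})

    mod = relax.transform.AnnotateTIROpPattern()(mod)
    # The pass attaches an integer pattern kind; 0 corresponds to kElemWise.
    print(mod["add_one"].attrs["op_pattern"])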
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/attach_global_symbol.cc + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + */ + +#include +#include + +namespace tvm { +namespace relax { + +class GlobalSymbolAttacher { + public: + explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {} + + IRModule Attach() { + IRModule ret; + for (auto& p : mod_->functions) { + BaseFunc func = p.second; + if (auto* prim_func = func.as()) { + func = WithAttr(GetRef(prim_func), "global_symbol", p.first->name_hint); + } else if (auto* relax_func = func.as()) { + func = WithAttr(GetRef(relax_func), "global_symbol", p.first->name_hint); + } else { + LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey(); + throw; + } + ret->Add(p.first, func); + } + return ret; + } + + private: + IRModule mod_; +}; + +namespace transform { + +Pass AttachGlobalSymbol() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return GlobalSymbolAttacher(mod).Attach(); }; + return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc new file mode 100644 index 000000000000..c444a84f44e0 --- /dev/null +++ b/src/relax/transform/bind_params.cc @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +void MatchSymbolicVar(const Expr& arg, const Expr& constant, + Map* symbolic_var_map, arith::Analyzer* analyzer_) { + auto opt_arg_sinfo = MatchStructInfo(arg); + CHECK(opt_arg_sinfo) + << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " + << GetStructInfo(arg); + auto opt_const_sinfo = MatchStructInfo(constant); + // As the constant is generated by internal codes, we use ICHECK here. 
+ ICHECK(opt_const_sinfo) + << "The struct info of the bound weight is expected to be TensorStructInfo, but got: " + << GetStructInfo(constant); + + TensorStructInfo arg_sinfo = opt_arg_sinfo.value(); + TensorStructInfo const_sinfo = opt_const_sinfo.value(); + ICHECK(!const_sinfo->IsUnknownDtype()); + ICHECK(!const_sinfo->IsUnknownNdim()); + ICHECK(const_sinfo->shape.defined()); + + // dtype mismatch + if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) { + LOG(FATAL) << "The dtype of the bound parameter is expected to be " << arg_sinfo->dtype + << ", but got: " << const_sinfo->dtype; + } + // ndim mismatch + if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) { + LOG(FATAL) << "The ndim of the bound parameter is expected to be " << arg_sinfo->ndim + << ", but got: " << const_sinfo->ndim; + } + if (!arg_sinfo->shape.defined()) return; + const auto* arg_shape = arg_sinfo->shape.value().as(); + const auto* const_shape = const_sinfo->shape.value().as(); + + CHECK(arg_shape && const_shape) + << "The shape of the bound parameter and weight is expected to be ShapeExprNode for now"; + + for (int i = 0; i < arg_sinfo->ndim; ++i) { + const PrimExpr& const_dim = const_shape->values[i]; + ICHECK(tir::is_const_int(const_dim)); + if (const auto* shape_var = arg_shape->values[i].as()) { + auto it = symbolic_var_map->find(GetRef(shape_var)); + if (it == symbolic_var_map->end()) { + symbolic_var_map->Set(GetRef(shape_var), const_dim); + } else { + CHECK(analyzer_->CanProveEqual((*it).second, const_dim)) + << "The shape of the bound parameter is expected to be " << (*it).second + << ", but got: " << const_dim; + } + } + } +} + +/*! + * \brief Bind params to function by using name + * \param func Relax function + * \param params params dict + * \return Function + */ +inline Function BindParamsByName(Function func, const Map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(name_dict[name]); + } else { + name_dict[name] = arg; + } + } + + arith::Analyzer analyzer; + Map bind_dict; + Map symbolic_var_map; + + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + const Var& arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first; + } + Expr const_expr = Constant(kv.second); + bind_dict.Set(arg, const_expr); + MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer); + } + Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); + Function ret = Downcast(bound_expr); + ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function." + << "\n"; + return ret; +} + +/*! + * \brief Bind params to a specific function in a module + * \param m The module + * \param func_name The name of the specific function + * \param param The param dict + * \return The module after binding params. 
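BindParamsByName and BindParam above are what the BindParams pass applies per function; from Python the pass takes the target function name plus a dict of NDArrays keyed by parameter name. A minimal sketch, assuming the TVMScript relax frontend from this change series:

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((4,), "float32"), w: R.Tensor((4,), "float32")):
            gv = R.add(x, w)
            return gv

    # Uses of `w` are replaced by an inlined constant; MatchSymbolicVar would also
    # resolve any symbolic dims against the constant's shape.
    params = {"w": tvm.nd.array(np.ones((4,), "float32"))}
    After = relax.transform.BindParams("main", params)(Before)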
+ */ +IRModule BindParam(IRModule m, String func_name, Map param) { + IRModuleNode* new_module = m.CopyOnWrite(); + Map functions = m->functions; + for (const auto& func_pr : functions) { + if (const auto* relax_f = func_pr.second.as()) { + if (relax_f->GetLinkageType() == LinkageType::kExternal) { + // Use global_symbol if it's external linkage + Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol.value() == func_name) { + Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + new_module->Update(func_pr.first, f_after_bind); + } + } else { + // Use global var's name_hint if it's internal linkage + if (func_pr.first->name_hint == func_name) { + Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + new_module->Update(func_pr.first, f_after_bind); + } + } + } + } + return GetRef(new_module); +} + +namespace transform { + +Pass BindParams(String func_name, Map params) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; + return CreateModulePass(pass_func, 0, "BindParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc new file mode 100644 index 000000000000..6066ed8d2a7d --- /dev/null +++ b/src/relax/transform/call_tir_rewrite.cc @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/call_tir_rewrite.cc + * \brief Perform explicit tensor allocation for call_tir. + */ +#include +#include +#include +#include +#include + +#include "../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// CallTIRMutator +// Perform explicit tensor allocation for call_tir or call_dps_packed. 
+// Example: +// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32") +// --> +// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32") +// rx.call_packed(func, x, gv0) + +class CallTIRMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const CallNode* call) override { + // post-order mutation + Expr expr = VisitExprPostOrder_(call); + call = expr.as(); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); + + if (call->op == call_tir_op || call->op == call_dps_packed_op) { + Array outs; + if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { + // single output case + const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); + ICHECK(tensor_sinfo->shape.defined()) + << "the TensorStructInfo shape of call_tir has not populated"; + outs.push_back( + builder_->Emit(Call(alloc_tensor_op, // + {Downcast(tensor_sinfo->shape.value()), + DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, // + Attrs()), + "alloc")); + } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { + // multiple output case + const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto& field = tuple_sinfo->fields[i]; + + ICHECK(field->IsInstance()) + << "call_tir expects Tuple of TensorStructInfo, but got " << field + << " as an element of TupleStructInfo"; + const auto& field_tensor = Downcast(field); + ICHECK(field_tensor->shape.defined()) + << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor + << " as an element of TupleStructInfo"; + outs.push_back( + builder_->Emit(Call(alloc_tensor_op, + {Downcast(field_tensor->shape.value()), + DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)}, + Attrs()), + "alloc")); + } + } else { + LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or " + "TupleStructInfo, but got" + << expr->struct_info_; + } + + Array args; + if (call->args[1].as()) { + args = Downcast(call->args[1])->fields; + args.insert(args.end(), outs.begin(), outs.end()); + + if (call->args.size() == 2) { + builder_->Emit(Call(call->args[0], args), "_"); + } else { + // unpack semantics + args.push_back(call->args[2]); + builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); + } + } else { + args = outs; + args.insert(args.begin(), call->args[1]); + builder_->Emit(Call(call->args[0], args), "_"); + } + + if (outs.size() == 1) { + return outs[0]; + } + return std::move(Tuple(outs)); + } + + return GetRef(call); + } +}; + +Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } + +namespace transform { + +Pass CallTIRRewrite() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CallTIRRewrite(f)); }; + return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc new file mode 100644 index 000000000000..962f76a376b6 --- /dev/null +++ b/src/relax/transform/canonicalize_bindings.cc @@ -0,0 +1,135 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_bindings.cc + * \brief Pass for simplifying modules by folding var bindings and match shape nodes. + * May include other forms of simplification in the future. + * Ideally should be used before constant folding and eliminating unused bindings. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class BindingCanonicalizer : public ExprMutator { + public: + BindingCanonicalizer() {} + + Expr VisitExpr_(const VarNode* op) override { + // remap first + Var v = Downcast(ExprMutator::VisitExpr_(op)); + if (!CanCanonicalizeVar(v)) { + return Downcast(v); + } + // visit again in case we need to do a substitution in the value + return ExprMutator::VisitExpr_(LookupBinding(v).as()); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + Var v = Downcast(ExprMutator::VisitExpr_(op)); + if (!CanCanonicalizeVar(v)) { + return Downcast(v); + } + return ExprMutator::VisitExpr_(LookupBinding(v).as()); + } + + void VisitBinding_(const VarBindingNode* binding) override { + // Unlike default visitor, we do not permit the checked type to change + // if the new value's checked type is different (this preserves user annotations) + Expr new_value = this->VisitExpr(binding->value); + Var new_var = this->VisitVarDef(binding->var); + + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + this->builder_->EmitNormalized(GetRef(binding)); + return; + } + + this->builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + // If we have a trivial shape check (the shape_ of LHS and RHS is the same), + // we can canonicalize to a var binding + Expr new_value = this->VisitExpr(binding->value); + + // if the LHS and RHS have the same struct info, we canonicalize to a var binding instead + if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } else if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info)); + } + } + + private: + bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2, + std::function check_eq) { + // annotations differ if one is present but not the other + // or they're both present and they differ + bool both_present = obj1.defined() && obj2.defined(); + bool neither_present = !obj1.defined() && !obj2.defined(); + return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2)); + } + + bool CanCanonicalizeVar(Var v) { + Optional value = LookupBinding(v); + // can replace only if the value is also a var + 
if (!value || !value.as()) { + return false; + } + Var parent_var = Downcast(value); + + // Cases when we conservatively do not unify: + // 1. checked_type_ or shape_ of the child differs from that of the parent + // In this case, we could be overriding user annotations. + // 2. If the child is a Var and the parent is a DataflowVar. + // That could result in a DataflowVar leaving the current DataflowBlock. + bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_, + [&](const ObjectRef& lhs, const ObjectRef& rhs) { + return tvm::StructuralEqual()(lhs, rhs); + }); + bool var_to_dataflow = (!v.as() && parent_var.as()); + return !annotations_differ && !var_to_dataflow; + } +}; + +Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); } + +namespace transform { + +Pass CanonicalizeBindings() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeBindings(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc new file mode 100644 index 000000000000..4f36cfbc0fed --- /dev/null +++ b/src/relax/transform/convert_layout.cc @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/convert_layout.cc + * \brief Automatic layout conversion pass, especially for axis swapping. + */ + +#include +#include +#include +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tir::Layout; + +/*! + * \brief Main logic to convert the layout of conv2d. Other ops + * can adapt to such layout conversion following conv2d accordingly. + * + * Structurally speaking, a Relax function is composed of a series of VarBinding and + * MatchCast. And a specific class of VarBindings is the basic unit we want to rewrite. + * Formally, they are of the form: + * + * var = Call(Op, [args], attrs) + * + * where Op is a specific op we want to rewrite, and attrs is the attributes of the op. + * var and args are all exprs with type Tensor or Tuple of Tensors. They might + * be vars, constants, or Tuple of vars and constants. + * + * We register the layout inference function for each op (FRelaxInferLayout), which accepts the + * current call, the desired layout of conv2d ops, and the layout map of previous vars. 
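The BindingCanonicalizer just shown folds a var that is merely bound to another var (and turns trivial MatchCast nodes into plain bindings). A small sketch of the kind of rebinding it removes, assuming the Python wrapper relax.transform.CanonicalizeBindings follows the C++ registration above:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((4,), "float32")):
            with R.dataflow():
                y = R.add(x, x)
                z = y                  # trivial var-to-var binding
                w = R.multiply(z, z)   # after the pass, these uses refer to `y` directly
                R.output(w)
            return w

    After = relax.transform.CanonicalizeBindings()(Before)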
The result + * of the layout inference function is contained in an InferLayoutOutput object, which contains 3 + * fields: input_layouts, output_layouts, and attr, which represents the expected input layout, + * output_layout and converted attrs of the new op call. + * + * The rewrite pass does the rewriting in a single forward pass, where for each Call(Op), + * we collect the current Layout of each input var, and let the InferLayout function to infer the + * desired layout of the output. The rewriter will use these info to convert + * the layout of inputs and attrs of the op call, and note down the new layout of the output. + * + * The desired layout of conv2d ops is a map from the name of the op to the desired layout of the + * desired feature map, weight and output. For example, if we want to convert the layout of conv2d + * from NCHW to NHWC, we can set the desired layout of conv2d to be {"conv2d": ["NHWC", "OHWI"]}. + * + * The way we represent the layout of a var is a NLayout object, which is a nested tuple of Layout. + * The incoming layout of the module will be set as the default layout (We use ABCD... as the + * default) Note that for operators like conv, pool, people typically use NHWC to refer to the axes. + * But to be generic and support more operators, we use ABCD... to refer to the axes. + * + * Note that currently the layout conversion of conv2d only support axis swapping, such as NCHW to + * NWHC. Packed layout such as NCHW to NCHW4c is not supported now. + */ +class LayoutConvertMutator : public ExprMutator { + public: + explicit LayoutConvertMutator(const Map>& desired_layouts) + : desired_layouts_(desired_layouts) {} + + private: + Array LayoutToIntegers(const Layout& layout) { + Array ret; + LayoutDecision src = InitialLayoutDecision(layout.ndim()); + for (size_t i = 0; i < layout.ndim(); ++i) { + ret.push_back(Integer(src->layout.IndexOf(layout[i]))); + } + return ret; + } + + Expr RewriteExpr(const Expr& expr, const NLayout& to) { + auto fvisitleaf = [&](const Expr& expr, std::array layouts) -> Expr { + NLayout from = layouts[0], to = layouts[1]; + if (NLayoutEqual()(from, to)) return expr; + // If not both from and to are unknown, then none of them can be unknown. + ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && + !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) + << "Cannot convert when exactly one of the layouts is unknown"; + const auto* tensor = GetStructInfoAs(expr); + ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; + Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout, + from.LeafValue()->layout, to.LeafValue()->layout); + return permute_dims(expr, LayoutToIntegers(axes)); + }; + return TransformTupleLeaf( + VarReplacer::Replace(expr, var_remap_), + std::array({GetNLayout(var_layout_map_, expr), to}), fvisitleaf); + } + + Array RewriteArgs(const Array& args, const Array& to) { + ICHECK(args.size() == to.size()); + std::vector new_args; + for (size_t i = 0; i < args.size(); ++i) { + new_args.push_back(RewriteExpr(args[i], to[i])); + } + return std::move(new_args); + } + + void VisitBinding(const Binding& binding) final { + // Emit the binding + ExprMutator::VisitBinding(binding); + // The layout is default to be initial if not rewritten. 
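Continuing the conv2d example from the comment above, the pass is driven by a desired-layout map from op name to layout strings; a hedged Python sketch (the module itself is a minimal illustration, not taken from this patch):

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Conv:
        @R.function
        def main(x: R.Tensor((1, 3, 224, 224), "float32"),
                 w: R.Tensor((16, 3, 3, 3), "float32")):
            with R.dataflow():
                y = R.nn.conv2d(x, w, padding=(1, 1))  # NCHW / OIHW by default
                R.output(y)
            return y

    # Rewrites the conv2d to NHWC activations / OHWI weights and inserts the
    # permute_dims calls needed to keep the function's interface unchanged.
    mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Conv)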
+ if (var_layout_map_.find(binding->var) == var_layout_map_.end()) { + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } + } + + Expr VisitVars_(const Var& var) { + // We encounter a var use outside of inferrable regions, we rewrite it to initial layout. + return RewriteExpr(var, InitialNLayout(var)); + } + + Expr VisitExpr_(const VarNode* op) final { return VisitVars_(GetRef(op)); } + + Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVars_(GetRef(op)); } + + bool HasUnknownDimTensor(const NLayout& nlayout) { + bool find = false; + auto fvisit = [&](const LayoutDecision& layout) { + find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim())); + }; + ForEachLeaf(nlayout, fvisit); + return find; + } + + bool HasUnknownDimTensor(const Array& args) { + for (const auto& arg : args) { + if (IsNestedTensor(arg)) { + if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) { + return true; + } + } + } + return false; + } + + Optional GetInferLayoutInfo(const CallNode* call_node, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const OpNode* op_node = call_node->op.as(); + if (op_node == nullptr) return NullOpt; + Op op = Downcast(GetRef(op_node)); + const auto attr_map = Op::GetAttrMap("FRelaxInferLayout"); + if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) { + // If the op has FRelaxInferLayout, and all the input tensors have known ndim + FRelaxInferLayout f = attr_map[op]; + return f(GetRef(call_node), desired_layouts, var_layout_map); + } else { + // Otherwise, we use the default policy. + return NullOpt; + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + Optional res = + GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_); + ObjectPtr new_call = make_object(*call_node); + new_call->struct_info_ = NullOpt; + if (!res.defined() || + (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { + // Default policy: use the initial layout. + // When we don't have the infer layout info, or it's a non-tensor global var binding. + std::vector input_layout; + for (const auto& arg : call_node->args) { + input_layout.push_back(InitialNLayout(arg)); + } + Array new_args = RewriteArgs(call_node->args, std::move(input_layout)); + new_call->args = std::move(new_args); + ReEmitBinding(binding, builder_->Normalize(Call(new_call))); + // update the layout map + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } else { + // Convert the layout according to the inferred layout output. + Array new_args = RewriteArgs(call_node->args, res.value()->input_layouts); + new_call->args = std::move(new_args); + new_call->attrs = std::move(res.value()->new_attrs); + Expr cur_call = builder_->Normalize(Call(new_call)); + if (binding->var->IsInstance()) { + // Dataflow var, we emit the rewritten call. 
+ ReEmitBinding(binding, cur_call); + // update the layout map + var_layout_map_[binding->var] = res.value()->output_layouts[0]; + } else { + // Global var (tensor), we rewrite it to initial layout + ICHECK(IsNestedTensor(binding->var)); + if (!NLayoutEqual()(res.value()->output_layouts[0], InitialNLayout(binding->var))) { + Var new_var = builder_->Emit(cur_call); + var_layout_map_[new_var] = res.value()->output_layouts[0]; + cur_call = builder_->Normalize(RewriteExpr(new_var, InitialNLayout(binding->var))); + } + ReEmitBinding(binding, cur_call); + // update the layout map + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { + std::vector input_layout; + for (const auto& field : val->fields) { + if (binding->var->IsInstance()) { + // Df var: Use the current realized layout to group the tuple; + input_layout.push_back(GetNLayout(var_layout_map_, field)); + } else { + // Global var: Use the initial layout to group the tuple; + input_layout.push_back(InitialNLayout(field)); + } + } + Array new_fields = RewriteArgs(val->fields, std::move(input_layout)); + if (IsNestedTensor(binding->var)) { + ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields))); + var_layout_map_[binding->var] = input_layout; + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { + NLayout input_layout = binding->var->IsInstance() + ? GetNLayout(var_layout_map_, val->tuple) + : InitialNLayout(val->tuple); + ReEmitBinding(binding, builder_->Normalize( + TupleGetItem(RewriteExpr(val->tuple, input_layout), val->index))); + // update the layout map + var_layout_map_[binding->var] = input_layout.NestedArray()[val->index]; + } + + void VisitBinding_(const MatchCastNode* binding) final { + if (!binding->var->IsInstance()) { + ExprMutator::VisitBinding_(binding); + return; + } + NLayout from_layout = InitialNLayout(binding->value); + NLayout input_layout = GetNLayout(var_layout_map_, binding->value); + auto fvisitleaf = [&](const StructInfo& sinfo, std::array layouts) -> StructInfo { + NLayout from = layouts[0], to = layouts[1]; + if (NLayoutEqual()(from, to)) return sinfo; + // If not both from and to are unknown, then none of them can be unknown. 
+ ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && + !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) + << "Cannot convert when exactly one of the layouts is unknown"; + const TensorStructInfoNode* tsinfo = sinfo.as(); + ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor struct"; + if (!tsinfo->shape.defined()) return sinfo; + const ShapeExprNode* shape = tsinfo->shape.value().as(); + if (shape == nullptr) return sinfo; + ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim()); + std::vector new_shape; + for (size_t i = 0; i < shape->values.size(); ++i) { + new_shape.push_back( + shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]); + } + return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, tsinfo->span); + }; + StructInfo new_struct_info = TransformTupleLeaf( + binding->struct_info, std::array({from_layout, input_layout}), fvisitleaf); + // re-emit old binding if nothing changes + if (new_struct_info.same_as(binding->struct_info)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + Var new_var = + builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), new_struct_info); + var_layout_map_[binding->var] = input_layout; + this->var_remap_[binding->var->vid] = new_var; + } + } + + std::unordered_map var_layout_map_; + Map> desired_layouts_; +}; // namespace relax + +DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block, + Map> desired_layouts) { + LayoutConvertMutator mutator(desired_layouts); + return Downcast(mutator.VisitBindingBlock(df_block)); +} + +namespace transform { + +Pass ConvertLayout(Map> desired_layouts) { + runtime::TypedPackedFunc pass_func = + [=](DataflowBlock df_block, IRModule m, PassContext pc) { + return Downcast(ConvertLayoutPass(df_block, desired_layouts)); + }; + return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc new file mode 100644 index 000000000000..fe36eb28ef61 --- /dev/null +++ b/src/relax/transform/dead_code_elimination.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/dead_code_elimination.cc + * \brief Dead code elimination pass. + * \sa tvm/relax/ir/binding_rewrite.cc + * + * Currently it removes: + * 1. Unused local VarBindings in a DataflowBlock. + * 2. Unused DataflowBlocks in a function. + * 3. Unused Relax functions in the module. + * We detect the call chain from the entry function, and remove all unused functions. 
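A short sketch of what this pass removes, assuming the Python wrapper relax.transform.DeadCodeElimination exposes the entry_functions argument of the C++ pass registered below:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((4,), "float32")):
            with R.dataflow():
                y = R.add(x, x)
                unused = R.multiply(x, x)  # dead binding inside the DataflowBlock
                R.output(y)
            return y

    # Dead bindings are dropped; internal functions unreachable from "main" would go too.
    After = relax.transform.DeadCodeElimination(["main"])(Before)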
+ */ + +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +/** + * \brief Detects all the functions that can be possibly called by entry function. + */ +class CallTracer : ExprVisitor { + public: + explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, visiting_{} {} + + void VisitExpr_(const GlobalVarNode* op) final { + called_funcs_.insert(GetRef(op)); + auto func = mod_->Lookup(op->name_hint); + if (const auto* function_node = func.as()) { + VisitExpr(GetRef(function_node)); + } + // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. + } + + void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } + + void VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + if (visiting_.find(func) == visiting_.end()) { + visiting_.insert(func); + for (auto param : func_node->params) { + ExprVisitor::VisitExpr(param); + } + ExprVisitor::VisitExpr(func_node->body); + } + } + + void Trace(std::string entry) { + called_funcs_.insert(mod_->GetGlobalVar(entry)); + auto main_func = mod_->Lookup(entry); + VisitExpr(main_func); + } + + bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; } + + private: + IRModule mod_; + + // Record the names of all encountered functions. + std::unordered_set called_funcs_; + + // Record the expressions that are being visited. + std::unordered_set visiting_; +}; + +IRModule RemoveUnusedFunctions(IRModule mod_, Array entry_funcs) { + auto tracer = CallTracer(mod_); + for (auto entry : entry_funcs) { + tracer.Trace(entry); + } + auto existing_functions = mod_->functions; + for (auto f : existing_functions) { + // If a function has an external linkage type, we do not remove it. + // Otherwise, we check the function and remove it if it is not used anywhere. + if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) { + mod_->Remove(f.first); + } + } + return mod_; +} + +IRModule DeadCodeElimination(const IRModule& mod, Array entry_functions) { + // S1: remove unused functions to reduce the number of functions to be analyzed. + IRModule tmp_mod = RemoveUnusedFunctions(mod, entry_functions); + // S2: remove unused variables in each function. + for (const auto& gv : tmp_mod->GetGlobalVars()) { + auto func = tmp_mod->Lookup(gv); + if (func->IsInstance()) { + tmp_mod->Update(gv, RemoveAllUnused(Downcast(func))); + } + } + // S3: remove unused functions again as some callers may be removed in S2. + return RemoveUnusedFunctions(tmp_mod, entry_functions); +} + +namespace transform { + +Pass DeadCodeElimination(Array entry_functions) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m, entry_functions); }; + return CreateModulePass(pass_func, 1, "DeadCodeElimination", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/decompose_composite_ops.cc b/src/relax/transform/decompose_composite_ops.cc new file mode 100644 index 000000000000..36814422216b --- /dev/null +++ b/src/relax/transform/decompose_composite_ops.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! \file src/relax/transform/simplify_norm_inference.cc */ + +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +TensorStructInfo MatchTensorStructInfo(Expr data) { + auto _sinfo = MatchStructInfo(data); + ICHECK(_sinfo.defined()) << "Expect data to be a tensor, but get " << GetStructInfo(data); + return _sinfo.value(); +} + +Expr ExpandToMatchInput(Expr data, int ndim, Array axes) { + axes = GetOrderedPositiveAxes(axes, ndim); + Array expand_axes; + for (int i = 0, j = 0; i < ndim; ++i) { + if (j < static_cast(axes.size()) && i == axes[j]->value) { + ++j; + } else { + expand_axes.push_back(i); + } + } + return expand_dims(data, expand_axes); +} + +Expr SimplifyBatchNorm(const CallNode* call) { + auto attrs = call->attrs.as(); + ICHECK_NOTNULL(attrs); + + Expr data = call->args[0]; + TensorStructInfo sinfo = MatchTensorStructInfo(data); + Expr gamma = call->args[1]; + Expr beta = call->args[2]; + Expr moving_mean = ExpandToMatchInput(call->args[3], sinfo->ndim, {attrs->axis}); + Expr moving_var = ExpandToMatchInput(call->args[4], sinfo->ndim, {attrs->axis}); + + // output = (x - mean) / sqrt(var + epsilon) * gamma + beta + Expr epsilon = MakeConstantScalar(static_cast(attrs->epsilon), sinfo->dtype); + Expr sqrt_var = sqrt(add(moving_var, epsilon)); + Expr out = divide(subtract(data, moving_mean), sqrt_var); + + if (attrs->scale) { + out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis})); + } + if (attrs->center) { + out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis})); + } + + return out; +} + +/*! \brief A mutator to simplify the normalization inference. */ +class NormInferenceSimplifier : public ExprMutator { + public: + static Expr Simplify(Expr expr) { return NormInferenceSimplifier()(expr); } + + private: + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr expr = ExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_NOTNULL(op); + + auto it = batch_norm_map_.find(op->tuple); + if (it != batch_norm_map_.end() && op->index == 0) { + return (*it).second; + } else { + return expr; + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { + ExprMutator::VisitBinding_(binding, val); + if (val->op == Op::Get("relax.nn.batch_norm")) { + // NOTE: we won't directly replace the batch_norm call since + // the following bindings may depend on the returned moving_mean and moving_var. + // Instead, we will store the unpacked value in the batch_norm_map_, and replace it + // at the TupleGetItemNode. And the original batch_norm call will be removed in the + // follow-up pass `RemoveAllUnused` + batch_norm_map_.Set(binding->var, SimplifyBatchNorm(val)); + } + } + + private: + /*! \brief The mapping from binding var of batch_norm to the unpacked value. 
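+   *
+   * For instance (illustrative), for a binding
+   *   bn = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, ...)
+   * the map records `bn -> (x - mean) / sqrt(var + epsilon) * gamma + beta`,
+   * so that a later `bn[0]` (TupleGetItem with index 0) can be rewritten to the
+   * unpacked expression, while `bn[1]` and `bn[2]` keep using the original call.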
*/ + Map batch_norm_map_; +}; + +class OpDecomposer : public ExprMutator { + public: + static Expr Decompose(Expr expr) { return OpDecomposer()(expr); } + + private: + using ExprMutator::VisitExpr_; + Expr TensorToShape(const Call& call_node) { + ICHECK(call_node->struct_info_.defined()); + Expr expr = call_node->args[0]; + const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); + ICHECK(sinfo); + // call builtin function that converts tensor to shape tuple + // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" + Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, + {GetRef(sinfo)})); + + // Operators like reshape take the output of `TensorToShape` as their output shape. + // Because TOPI expects to have such output shape in symbolic shape at least (i.e., + // Array), we define symbolic variables and returns them as a ShapeExpr. + Array shape_var; + for (int i = 0; i < sinfo->ndim; i++) { + shape_var.push_back(tir::Var("x", DataType::Int(64))); + } + // bind symbolic variables to the shape tuple + relax::Var var("y", ShapeStructInfo(shape_var)); + builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var))); + return ShapeExpr(shape_var); + } + + Expr VisitExpr_(const CallNode* call_node) final { + Call call = Downcast(VisitExprPostOrder_(call_node)); + if (call->op == tensor_to_shape_op_) { + return TensorToShape(call); + } else { + return call; + } + } + + const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); +}; + +namespace transform { +Pass DecomposeCompositeOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + f = Downcast(NormInferenceSimplifier::Simplify(f)); + f = Downcast(OpDecomposer::Decompose(f)); + // Remove original ops if it's not used. + return RemoveAllUnused(f); + }; + return CreateFunctionPass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"DecomposeCompositeOps", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc new file mode 100644 index 000000000000..9c9252ddfa72 --- /dev/null +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/eliminate_common_subexpr.cc + * \brief Eliminrate common subexpression pass. + * + * Currently it removes common subexpressions within a DataflowBlock. 
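+ *
+ * For example (illustrative):
+ *   lv0 = R.add(x, y)
+ *   lv1 = R.add(x, y)
+ *   gv = R.multiply(lv0, lv1)
+ * is rewritten so that the duplicated `R.add(x, y)` is computed only once:
+ *   lv0 = R.add(x, y)
+ *   lv1 = lv0
+ *   gv = R.multiply(lv0, lv1)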
+ */ +#include +#include + +namespace tvm { +namespace relax { + +class SubexprCounter : public ExprVisitor { + public: + // overriding VisitExpr ensures we do this for every subexpression + void VisitExpr(const Expr& e) override { + // Cases we ignore because we will not substitute them: + // 1. Vars of all kinds + // 2. Op nodes (nothing we can do) + // 3. Scalar constants (not much benefit from binding to a var) + if (!(e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || + (e.as() && (e.as()->is_scalar())))) { + int count = 0; + if (count_map_.count(e)) { + count = count_map_.at(e); + } + count_map_[e] = count + 1; + } + ExprVisitor::VisitExpr(e); + } + + // do not visit inner functions: we will do CSE within those + void VisitExpr_(const FunctionNode* func) override {} + + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + + std::unordered_map Count( + const DataflowBlock& df_block) { + for (auto binding : df_block->bindings) { + VisitBinding(binding); + } + return count_map_; + } + + private: + std::unordered_map count_map_; +}; + +// forward declaration +DataflowBlock EliminateCommonSubexpr(const DataflowBlock&); + +class CommonSubexprEliminator : public ExprMutator { + public: + explicit CommonSubexprEliminator( + const std::unordered_map& count_map) + : count_map_(count_map) {} + + // overriding here ensures we visit every subexpression + Expr VisitExpr(const Expr& e) override { + if (count_map_.count(e) && count_map_.at(e) > 1) { + // if we already have a mapping for it, get it + if (replacements_.count(e)) { + return replacements_.at(e); + } + // Otherwise, insert a new binding for the current expression. 
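+      // (The emitted var is recorded in `replacements_`, so every later
+      // occurrence of the same expression resolves to that var.)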
+ // Visit before emitting to do inner replacements + Expr new_e = ExprMutator::VisitExpr(e); + Var v = builder_->Emit(new_e); + replacements_[e] = v; + return v; + } + return ExprMutator::VisitExpr(e); + } + + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { + return struct_info; + } + + Expr VisitExpr_(const FunctionNode* func) override { + // for an inner function, we will do CSE on its body + Expr new_body = ExprMutator::VisitExpr(func->body); + if (new_body.same_as(func->body)) { + return GetRef(func); + } + return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + } + + // this should happen only for the inner function case + Expr VisitExpr_(const SeqExprNode* seq) override { + bool all_unchanged = true; + Array new_blocks; + // apply CSE within dataflow blocks only + for (auto block : seq->blocks) { + if (const DataflowBlockNode* df_block = block.as()) { + auto new_df_block = EliminateCommonSubexpr(GetRef(df_block)); + if (!new_df_block.same_as(block)) { + new_blocks.push_back(new_df_block); + all_unchanged = false; + continue; + } + } + new_blocks.push_back(block); + } + + if (all_unchanged) { + return GetRef(seq); + } + // do not visit the body + return SeqExpr(new_blocks, seq->body, seq->span); + } + + void VisitBinding_(const VarBindingNode* binding) override { + // no need to visit var def because the struct info isn't going to change + Expr new_value = RegisterBoundValue(binding->var, binding->value); + + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + // no need to renormalize new_value because all replacements are with vars + builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + } + } + + void VisitBinding_(const MatchCastNode* binding) override { + // no need to visit var def because the struct info isn't going to change + Expr new_value = RegisterBoundValue(binding->var, binding->value); + + // re-emit old binding if nothing changes + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + // no need to renormalize new_value because all replacements are with vars + builder_->EmitNormalized( + MatchCast(binding->var, new_value, binding->struct_info, binding->span)); + } + } + + private: + Expr RegisterBoundValue(Var var, Expr bound_value) { + // special case: if we are processing a binding + // and this is the first time we've encountered it, + // we will use the binding's var for the mapping + bool newly_replaced = false; + if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 && + !replacements_.count(bound_value)) { + replacements_[bound_value] = var; + newly_replaced = true; + } + + if (newly_replaced) { + // If we've just added the mapping, using the overridden visitor will + // just return the var, which we don't want, so we will use + // the superclass VisitExpr to do inner substitutions + return ExprMutator::VisitExpr(bound_value); + } + return VisitExpr(bound_value); + } + + const std::unordered_map& count_map_; + std::unordered_map replacements_; +}; + +DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) { + SubexprCounter counter; + auto count_map = counter.Count(df_block); + CommonSubexprEliminator eliminator(count_map); + return Downcast(eliminator.VisitBindingBlock(df_block)); +} + +namespace transform { + +Pass EliminateCommonSubexpr() { + 
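+  // Illustrative usage (assuming the global registration below): from Python,
+  //   mod = relax.transform.EliminateCommonSubexpr()(mod)
+  // applies the rewrite to every dataflow block in the module.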
runtime::TypedPackedFunc pass_func = + [=](DataflowBlock df_block, IRModule m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(df_block)); + }; + return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") + .set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc new file mode 100644 index 000000000000..db30900cd2a1 --- /dev/null +++ b/src/relax/transform/fold_constant.cc @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ConstantFolder : public ExprMutator { + public: + static Function Fold(Function func, IRModule ctx_module) { + ConstantFolder folder(std::move(ctx_module)); + func = RemoveAllUnused(Downcast(folder(func))); + return func; + } + + private: + explicit ConstantFolder(IRModule ctx_module) : ExprMutator(ctx_module) {} + + /*! + * \brief Pattern match the shape inside the given struct info to a + * constant shape and get runtime shape tuple from it. + * \param struct_info The given struct info whose shape inside is to be casted. + * \return The runtime shape tuple, or nullopt if it is not a constant shape. + * \note Only TensorStructInfo is supported at this moment. Return NullOpt + * if the input struct info is not TensorStructInfo. + */ + static Optional MatchConstShape(const StructInfo& struct_info) { + // Only support single output for call_tir at this moment. + const auto* tensor_sinfo = struct_info.as(); + if (tensor_sinfo == nullptr) { + return NullOpt; + } + + const auto* shape = tensor_sinfo->shape.as(); + ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; + + std::vector shape_values; + for (const auto v : shape->values) { + auto* ptr = v.as(); + if (!ptr) return NullOpt; + shape_values.push_back(ptr->value); + } + return runtime::ShapeTuple(shape_values.begin(), shape_values.end()); + } + + /*! + * \brief Pattern match op to constant array arguments. + * \return The constant array arguments, or nullopt if match fails. + */ + static Optional> MatchConstArrayArgs(const Array& args) { + Array res; + for (auto arg : args) { + auto* ptr = arg.as(); + if (!ptr) return NullOpt; + res.push_back(ptr->data); + } + return res; + } + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or nullopt if pattern match fails. 
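+ *
+ * Illustrative example: for `R.call_tir(@my_add, ...)` (with a hypothetical
+ * PrimFunc name `@my_add`), `op` is that GlobalVar, and the lookup returns the
+ * PrimFunc bound to it in the context module, or NullOpt if the GlobalVar
+ * refers to a non-TIR function.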
+ */ + Optional MatchPrimFunc(const Expr& op) { + const GlobalVar& global_var = Downcast(op); + // NOTE: as check works for nullptr(returns null) + Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); + } + return NullOpt; + } + + /*! + * \brief Get a cached build version of func + * \return The cached func, nullopt if func cannot be built. + */ + Optional GetCachedBuild(tir::PrimFunc func) { + // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once + // would be helpful for future cases where PrimFunc recursively call into each other + Target eval_cpu_target{"llvm"}; + + auto it = func_build_cache_.find(func); + if (it != func_build_cache_.end()) { + return it->second; + } + Optional build_func = NullOpt; + + try { + // Not all the primfunc can be directly built via llvm, for example, if a function is + // already scheduled to only work on GPU, we will need to skip this in the const folder for + // now + // TODO(Hongyi): further check and narrow the scope of foldable function + runtime::Module rt_module = + build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + build_func = rt_module.GetFunction("tir_function"); + } catch (const tvm::Error& err) { + // build failure may happen in which case we skip + DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); + } + func_build_cache_[func] = build_func; + return build_func; + } + + /*! + * \brief Checks if it is useful to fold \p expr. + * \details Folding an expr is a trade-off - we are materializing a constant in the IRModule and + * paying compile time cost to avoid the cost of executing this expr at runtime. For example, + * folding iota ops could result in large constants being materialized, thus increasing the size + * of the program. + */ + bool ShouldBeFolded(Expr expr) { + // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or + // not. + return true; + } + + // Try constant evaluate the function call + // if failed return NullOpt + Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, + runtime::ShapeTuple shape, DataType ret_type) { + // obtain function from the cache. + Optional func = GetCachedBuild(tir_func); + if (!func) return NullOpt; + + // here the vector size has an additional + 1 because we need to put ret_tensor at the end + std::vector values(arr_args.size() + 1); + std::vector type_codes(arr_args.size() + 1); + + DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; + runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev); + + // avoid set rvalue ref which get de-allocated later, store args in a vector + // where temp_args[i] are lvalue ref that is stable + std::vector temp_args(arr_args.begin(), arr_args.end()); + + size_t arg_offset = 0; + for (; arg_offset < arr_args.size(); ++arg_offset) { + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset, temp_args[arg_offset]); + } + // set return value + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++, ret_tensor); + + TVMRetValue ret; + // invoke + func.value().CallPacked(TVMArgs(values.data(), type_codes.data(), values.size()), &ret); + return Constant(ret_tensor); + } + + // Returns the folded expr if the call is successfully folded to constant, otherwise null. 
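+  //
+  // Illustrative example (with a hypothetical PrimFunc @add_one): a call like
+  //   lv = R.call_tir(@add_one, (c,), R.Tensor((3,), "int64"))
+  // where `c` is a Constant with a static output shape is built for CPU,
+  // executed at compile time, and `lv` is folded to the resulting Constant.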
+ Optional VisitCallTIR(Call call) { + // call_tir needs to have at least three arguments + ICHECK_GE(call->args.size(), 2); + Optional func = MatchPrimFunc(call->args[0]); + ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; + Optional> arr_args = + MatchConstArrayArgs(call->args[1].as()->fields); + ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; + Optional shape = MatchConstShape(call->sinfo_args[0]); + bool output_not_tuple = call->sinfo_args.size() == 1; + // Pattern 0: call constant function, const argument with const shape. + if (func && arr_args && shape && output_not_tuple) { + DynTensorType ret_type = Downcast(call->checked_type()); + // value_or will return value if it is not null, otherwise return or + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype) + .value_or({}); + } + // TODO(hongyi): support const-fold tuple outputs + return {}; + } + + using ExprMutator::VisitExpr_; + + // TODO(@sunggg): + // Next PR will support fold with PackedFunc and MatchCast + // Until then, DecomposeCompositeOps() should be applied after + // this pass to fold `tensor_to_shape` op. + Expr VisitExpr_(const CallNode* call) final { + // post-order mutation + Call post_call = Downcast(VisitExprPostOrder_(call)); + + // Check if it is useful to fold this call + if (!ShouldBeFolded(post_call)) return post_call; + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + auto* op_node = post_call->op.as(); + + // Not an OpNode + if (op_node == nullptr) { + return post_call; + } + auto op = GetRef(op_node); + + if (op.same_as(call_tir_op)) { + return VisitCallTIR(post_call).value_or(post_call); + } + + // Special logic to fold ShapeExpr between operators + // e.g., + // + // lv: R.Shape([16, 16]) = R.shape([16, 16]) + // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, lv) + // + // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16])) + // + Array new_args; + for (auto arg : post_call->args) { + if (arg->IsInstance()) { + Optional val = LookupBinding(Downcast(arg)); + if (val.defined() && val.value()->IsInstance()) { + new_args.push_back(val.value()); + continue; + } + } + new_args.push_back(arg); + } + post_call = + Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span); + + // If we are in a dataflow block, we can fold ops. + if (builder_->CurrentBlockIsDataFlow()) { + // Check if we can them to call_tir + if (legalize_map.count(op)) { + // Get the legalized expression + Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call)); + // If the legalized expression is call_tir, try to fold it. + const CallNode* call = legalized_expr.as(); + if (call && call->op.same_as(call_tir_op)) { + return VisitCallTIR(GetRef(call)).value_or(post_call); + } + } else if (op->name == "relax.tensor_to_shape") { + // Special handling for composite op "relax.tensor_to_shape" + // If its input is constant, we can access its value and create ShapeExpr + // TODO(@sunggg): + // currently, we do not have a info map about decomposition. + // Thus, this is a temporary solution until we have a consensus about + // how to deal with composite ops. One possibility is we register the + // decomposition map for each op in a similar way we do for legalization. 
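+          // (Illustrative: R.tensor_to_shape(R.const([16, 16], "int64")) folds to
+          // the shape expression R.shape([16, 16]).)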
+ ICHECK_EQ(post_call->args.size(), 1); + Expr arg = post_call->args[0]; + if (arg->IsInstance()) { + Constant constant = Downcast(arg); + runtime::NDArray ndarray = constant->data; + ICHECK_EQ(ndarray->device.device_type, kDLCPU); + ICHECK(ndarray->strides == nullptr); + ICHECK_EQ(ndarray->byte_offset, 0); + ICHECK_EQ(ndarray->ndim, 1); + const int64_t* data = static_cast(ndarray->data); + int64_t num_elems = ndarray->shape[0]; + Array shape_values; + for (int64_t i = 0; i < num_elems; i++) { + shape_values.push_back(IntImm(DataType::Int(64), data[i])); + } + return ShapeExpr(shape_values); + } + } else if (op->name == "relax.shape_to_tensor") { + // Special handling for "relax.shape_to_tensor" since it is implemented in PackedFunc. + // TODO(sunggg): revisit this when we extend ConstantFolding to fold PackedFunc. + Expr arg = post_call->args[0]; + ShapeExpr shape = Downcast(arg); + Array values = shape->values; + Array arr; + bool is_known = true; + for (size_t i = 0; i < values.size(); i++) { + PrimExpr val = values[i]; + arr.push_back(GetRef(val.as())); + is_known &= (val.dtype() == DataType::Int(64)); + } + if (is_known) { + const auto* func = tvm::runtime::Registry::Get("relax.run.shape_to_tensor"); + ICHECK(func != nullptr); + runtime::NDArray vals = (*func)(arr); + return Constant(vals); + } + } + } + + return std::move(post_call); + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + Optional opt = LookupBinding(GetRef(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const VarNode* op) final { + Optional opt = LookupBinding(GetRef(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + // cache for function build, via structural equality + std::unordered_map, StructuralHash, StructuralEqual> + func_build_cache_; +}; + +namespace transform { + +Pass FoldConstant() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); }; + return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc new file mode 100644 index 000000000000..8e4346e2062b --- /dev/null +++ b/src/relax/transform/fuse_ops.cc @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relax/transform/fuse_ops.cc + * \brief This file contains a pass which groups bindings in a dataflow block of Relax + * functions and generate a new grouped Relax function for each group, according to the fusion + * algorithm described below. By grouping bindings into new Relax functions, we substitute the + * bindings in the function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" +#include "tvm/relax/expr.h" + +namespace tvm { +namespace relax { + +/* + Note on Fusing algorithm: + + The main challenge of general fusor is to handle possible diamond shape branches, + in the following graph, conv2d can be fused to elemwise add. + + conv2d + / | \ + / | \ + op op op + \ | / + \ | / + elemwise add + | + + However, at the point of conv2d we do not necessarily know that all the future paths + will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. + + The immediate post-dominator of a node defined by the closest node where all the future path goes + into. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: + + - Construct a DAG of dataflow graph for dominator analysis + - Construct a post-dominator tree which gives immediate post dominator of each node. + - Run fusion algorithm with the given post-dominator information. + + Note that, because we run analysis on a DAG, we use a single pass post-dominator + tree construction algorithm via LCA, which is simpler than the full version that handles cycles. + + The fusion algorithm traverses from each node and checks if it can be fused to its + immediate post dominator. It has to check the following things: + + - CheckPath: check all the path between a node and its immediate post-dominator + satisfies the fuse condition. + - Note that these intermediate node can already be fused with another nodes, the algorithm + will still run correctly. + - CommitFuse: mark all the nodes between source and post-dominator as the same group. + - We use an Union-Find data structure to manage the groups. +*/ + +using relay::GraphPartitioner; +using relay::IndexedForwardGraph; +using relay::OpPatternKind; +using support::LinkNode; + +constexpr uint32_t kMaxFusedOps = 256; + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer); + +class GraphCreator : public ExprVisitor { + public: + /*! + * \brief Create a IndexedForwardGraph according to the input module. The graph will be used for + * graph partition and operator fusion. + * \param mod The module which the creation accords to + * \param arena The allocator of all the internal node objects + * \return The created IndexedForwardGraph + */ + static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) { + GraphCreator creator(mod, arena); + for (const auto& it : mod->functions) { + // Only visit Relax function without attr kPrimitive. + const auto* func = it.second.as(); + if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + creator(GetRef(func)); + } + + // The algorithm of the graph creator ensures that each created node will be added to the + // post-dfs order and will be set its op pattern. 
Thus we check whether all these containers + // have the same size. + size_t n_nodes = creator.graph_.node_map.size(); + ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); + ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); + + return creator.graph_; + } + + private: + explicit GraphCreator(IRModule mod, support::Arena* arena) + : mod_(std::move(mod)), arena_(arena) {} + + void VisitExpr_(const FunctionNode* func) final { + for (const Var& param : func->params) { + IndexedForwardGraph::Node* param_node = CreateNode(param.get()); + // The parameter is passed in from the outside, and thus it's marked as an external reference, + // and it's pattern is `kOpaque`. + MarkAsExternRef(param_node); + SetNodePattern(param_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(param_node, param.get()); + } + ExprVisitor::VisitExpr_(func); + } + + void VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + } + + // TODO(tvm-team): how to deal with MatchCast binding here + + void VisitBinding_(const VarBindingNode* binding) final { + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); + + // If the variable is not a dataflow variable, it must be the output variable of this dataflow + // block + if (!binding->var->IsInstance()) { + this->MarkAsExternRef(node); + } + if (const auto* call = binding->value.as()) { + // Case 1. The expression is a CallNode + VisitCall(call, node); + } else if (const auto* tuple_get_item = binding->value.as()) { + // Case 2. The expression is a TupleGetItemNode + VisitTupleGetItem(tuple_get_item, node); + } else { + VisitUnsupportedNode(binding->value, node); + // Case 3. The type of the expression is not fusion-supported. + // In this case, we skip adding edges, adding an empty node into graph. + } + AddToPostDFSOrder(node, binding->var.get()); + } + + /********** Non-Leaf Expression Nodes **********/ + + void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + OpPatternKind pattern = OpPatternKind::kOpaque; + Array args = call->args; + + // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the + // function attribute and visit the arguments one by one. + // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we + // recurse into the call expression. + const auto* op = call->op.as(); + if (op == call_tir_op_.get()) { + const GlobalVar& global_var = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + + // Override args for call_tir + args = Downcast(call->args[1])->fields; + + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } + } + // The pattern of the current binding variable node is set to the pattern of this operator. 
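+    // (e.g. a PrimFunc annotated with "op_pattern" = kElemWise is treated as
+    // elementwise and may fuse with its neighbours, whereas a missing attribute
+    // conservatively falls back to kOpaque.)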
+ SetNodePattern(binding_var_node, pattern); + // Visit all call args + for (const Expr& arg : args) { + ICHECK(IsLeafOrTuple(arg)); + VisitLeaf(arg, binding_var_node, pattern); + } + } + + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, + IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kInjective); + VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective); + } + + void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + SetNodePattern(binding_var_node, OpPatternKind::kOpaque); + + auto visit_leaves = [this, &binding_var_node](const Expr& e) { + if (e->IsInstance() || e->IsInstance()) { + VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque); + } + }; + PostOrderVisit(expr, visit_leaves); + } + + /********** Leaf Expression Nodes **********/ + + void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, + const OpPatternKind& pattern) { + ICHECK_NOTNULL(binding_var_node); + + // Recursive visit if it's Tuple + if (const auto* tuple = leaf_expr.as()) { + for (const Expr& expr : tuple->fields) { + VisitLeaf(expr, binding_var_node, pattern); + } + return; + } + + if (!leaf_expr->IsInstance()) { + // Skip GlobalVar, ExternFunc, OpNode. + return; + } + + auto it = graph_.node_map.find(leaf_expr.get()); + IndexedForwardGraph::Node* leaf_node = nullptr; + if (it != graph_.node_map.end()) { + leaf_node = it->second; + } else if (leaf_expr->IsInstance()) { + leaf_node = CreateNode(leaf_expr.get()); + // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. + SetNodePattern(leaf_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(leaf_node, leaf_expr.get()); + } else { + LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr + << " used before definition."; + } + AddEdge(leaf_node, binding_var_node, pattern); + } + + /********** Helper Functions **********/ + + /*! + * \brief Create a graph node corresponding to the input key + * \param key The object which is used to create the graph node + * \return The created graph node + * \note The node corresponding to each key is supposed to be created for only once + */ + IndexedForwardGraph::Node* CreateNode(const Object* key) { + ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + << "The node corresponding to the input key is not supposed to be created before"; + auto* node = arena_->make(); + graph_.node_map[key] = node; + return node; + } + + /*! + * \brief Append the input node to the post-dfs order of the graph + * \param node The node to be appended + * \param key The key corresponding to the node + * \note Each node is supposed to be appended to the post-dfs order for only once + */ + void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { + auto it = graph_.node_map.find(key); + ICHECK(it != graph_.node_map.end() && it->second == node) + << "The node must have been created before adding to the post-dfs order"; + + // We only set the reference of the node when adding it to the post-dfs order. Thus, if the + // reference of a node is already set, it must have been appended to the post-dfs order. + ICHECK(node->ref == nullptr) + << "The node is not supposed to be added into the post-dfs order before"; + + node->ref = key; + node->index = graph_.post_dfs_order.size(); + graph_.post_dfs_order.push_back(node); + } + + /*! 
+ * \brief Add an edge from the input start to the input end in the graph, with specific pattern + * \param start The start of the edge + * \param end The end of the edge + * \param pattern The pattern of this edge + */ + void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end, + OpPatternKind pattern) { + auto* link = arena_->make>(); + link->value.node = end; + link->value.pattern = pattern; + start->outputs.Push(link); + } + + /*! + * \brief Mark a given node as "external reference", which means the node cannot be fused as an + * intermediate node + * \param node The graph node to be marked + */ + void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = true; } + + /*! + * \brief Set the pattern of the input node + * \param node The graph node to be set + * \param pattern The pattern of the node + */ + void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { + ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) + << "The input node is supposed to be set pattern for only once"; + initialized_nodes_.insert(node); + node->pattern = pattern; + } + + private: + /*! \brief The IRModule from which the indexed forward graph is created */ + IRModule mod_; + /*! \brief The allocator of all the internal node objects */ + support::Arena* arena_; + /*! \brief The created indexed forward graph */ + IndexedForwardGraph graph_; + /*! \brief The graph nodes whose patterns are set */ + std::unordered_set initialized_nodes_; +}; + +/*! + * \brief Renew the definition of symbolic vars in Relax. + * \details This mutator is used to prevent the same symbolic var from being used in different + * functions, which is malformed. + */ +class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { + public: + static Function Renew(const Function& function) { + SymbolicVarRenewMutator mutator; + return Downcast(mutator.VisitExpr(function)); + } + + private: + SymbolicVarRenewMutator() = default; + using relax::ExprMutator::VisitExpr; + using relax::ExprMutator::VisitExpr_; + using tir::ExprMutator::VisitExpr_; + + PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); } + + // TODO(Siyuan): enhance the method to the following steps: + // 1. Visit and replace all tir::Vars at the definition point + // 2. Revisit the function again and update the use side. + PrimExpr VisitExpr_(const tir::VarNode* op) final { + auto it = var_map_.find(GetRef(op)); + if (it != var_map_.end()) { + return (*it).second; + } else { + auto n = make_object(*op); + tir::Var v(n); + var_map_.Set(GetRef(op), v); + return v; + } + } + + private: + Map var_map_; +}; + +/*! + * \brief The ExprMutator used to create a new grouped function + * \details The workflow of this ExprMutator is: + * - The bindings in the function will be added by OperatorFusor via `AppendBinding(...)`. + * - When adding a new binding through `AppendBinding(...)`, we check whether the variables and + * constants used by the binding are defined by some previous added binding. And for the undefined + * variables and constants, we add them to the argument list and created new variables as the + * corresponding parameters. + * - When `CreateFunction()` is called, we go through each binding and update the binding with the + * new parameters. After that we wrap all bindings with a DataflowBlock and a Function. + */ +class FunctionCreator : public ExprMutator { + public: + explicit FunctionCreator(bool lift_constant) : lift_constant_(lift_constant) {} + /*! 
+ * \brief Append a new binding to this function and possibly create new parameters for the + * function accordingly + * \param binding The binding to be appended + * \note Allowed bindings are: + * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a tuple-get-item node. + * // TODO(tvm-team): handle match shape + */ + void AppendBinding(const Binding& binding) { + ICHECK(!function_.defined()) + << "The `function_` is supposed to be uncreated when adding bindings"; + + if (const auto* var_binding = binding.as()) { + if (const auto* call = var_binding->value.as()) { + if (call->op == Op::Get("relax.call_tir")) { + // Update the name of the function. + name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; + + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr + } else { + if (call->op->IsInstance()) { + name_hint_ = name_hint_ + "_" + Downcast(call->op)->name; + } else if (call->op->IsInstance()) { + std::string gvar_name = Downcast(call->op)->name_hint; + if (auto pos = gvar_name.find("fused_"); pos == 0) { + name_hint_ = name_hint_ + "_" + gvar_name.substr(std::string("fused_").size()); + } else { + name_hint_ = name_hint_ + "_" + gvar_name; + } + } + + for (const Expr& arg : call->args) { + CheckDefAndUpdateParam(arg); + } + } + } else { + const auto* tuple_item = var_binding->value.as(); + ICHECK(tuple_item != nullptr); + CheckDefAndUpdateParam(tuple_item->tuple); + } + + // Mark the binding variable as defined. + defined_vars_.insert(var_binding->var.get()); + // Set var as output true if the binding is not a dataflow variable + if (!var_binding->var->IsInstance()) { + AppendOutput(var_binding->var); + } + } else { + // TODO(tvm-team): handle match_cast + } + bindings_.push_back(binding); + } + + /*! \brief Set a var defined in the group as output. */ + size_t AppendOutput(const Var& var) { + ICHECK(defined_vars_.count(var.get())); + auto output_idx = GetOutputIndex(var); + if (output_idx) { + return *output_idx; + } + output_vars_.push_back(var.get()); + return output_vars_.size() - 1; + } + + /*! + * \brief Create the grouped function according according to the collected bindings and parameters + * \param composite_name The name to identify the pattern this function is created from, if any. + * It will become the value of the kComposite attribute of the created function. + * \note The created function won't be returned immediately. It's stored in the `function_` field. + */ + void CreateFunction(Map group_attrs) { + // Step 1. Start constructing a new dataflow block. + builder_->BeginDataflowBlock(); + + // Step 2. Visit each binding and collect outputs one by one. + Array outputs(output_vars_.size(), Expr()); + for (const Binding& binding : bindings_) { + if (auto output_idx = GetOutputIndex(binding->var)) { + // Case 1. It is an output binding + // We only allow VarBinding as output. + const auto* var_binding = binding.as(); + ICHECK_NOTNULL(var_binding); + Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value)); + var_remap_[var_binding->var->vid] = output_var; + outputs.Set(*output_idx, output_var); + } else { + // Case 2. It is an internal binding, add it to the binding list. + VisitBinding(binding); + } + } + + // Step 3. Finish constructing the new block. 
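+    // (Illustrative: for a group of two call_tir bindings to PrimFuncs `conv2d`
+    // and `relu`, the resulting function is named "fused_conv2d_relu", carries
+    // the kPrimitive attribute, and returns the group's output var(s).)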
+ BindingBlock new_block = builder_->EndBlock(); + if (outputs.empty()) { + // If the result is not used outside + LOG(WARNING) << "There are dead codes in the current IRModule, please run the " + "DeadCodeElimination Pass before FuseOps"; + function_ = NullOpt; + } else { + Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs); + body = builder_->Normalize(body); + body = builder_->Normalize(SeqExpr({new_block}, body)); + group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); + function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_, // + /*body=*/body, // + /*ret_struct_info=*/NullOpt, // + /*attrs=*/DictAttrs(group_attrs))); + } + } + + /*! \brief The original bindings of the function */ + Array bindings_; + /*! \brief The parameters of the function */ + Array params_; + /*! \brief The arguments to call the function on the caller side */ + Array arguments_; + /*! \brief The name for the fused function */ + String name_hint_ = "fused"; + /*! \brief The constructed Relax function */ + Optional function_ = NullOpt; + + private: + std::optional GetOutputIndex(Var v) { + auto it = std::find(output_vars_.begin(), output_vars_.end(), v.get()); + if (it != output_vars_.end()) { + return std::distance(output_vars_.begin(), it); + } + return std::nullopt; + } + + /*! + * \brief Check whether the input expression is defined within this function. If not, create a new + * parameter for the expression. + * \param expr The expression to be checked + */ + void CheckDefAndUpdateParam(const Expr& expr) { + // If the expression has already served as an argument, no need to create another one for it. + if (std::find(arguments_.begin(), arguments_.end(), expr) != arguments_.end()) { + return; + } + + // If the expression is not a variable or is a undefined variable, it should be populated as a + // parameter of the relax function. + const auto* var = expr.as(); + if ((var == nullptr || defined_vars_.count(var) == 0) && + (lift_constant_ || !expr->IsInstance())) { + String name{nullptr}; + if (var != nullptr) { + name = var->name_hint(); + } else { + name = String("param_" + std::to_string(n_param_for_const_++)); + } + + Var param(std::move(name), GetStructInfo(expr)); + arguments_.push_back(expr); + params_.push_back(param); + } + } + + Expr VisitExpr(const Expr& expr) final { + // If the expression serves as an argument, return its correspondng parameter. + auto it = std::find(arguments_.begin(), arguments_.end(), expr); + if (it != arguments_.end()) { + return params_[it - arguments_.begin()]; + } + // Otherwise, recurse into this expression. + return ExprMutator::VisitExpr(expr); + } + + private: + /*! \brief The variables defined in this function */ + std::unordered_set defined_vars_; + /*! \brief The number of parameters reserved for constants */ + int n_param_for_const_ = 0; + /*! \brief The output vars */ + std::vector output_vars_; + /*! \brief Whether or not to lift bound constants to parameters */ + bool lift_constant_; +}; + +/*! + * \brief The ExprMutator used to fuse the operators in Relax functions + * \details Given the partition results on the indexed-forward graph, for each group whose size is + * larger than one, we create a new grouped function for it, containing all bindings in that group. + * And we substitute the bindings in a group with a single function call to the newly created + * grouped function. The workflow of this ExprMutator is: for each dataflow block, + * - we go through the bindings one by one. 
For each binding, if it is in a group whose size is + * larger than one, we add the binding to the function of the group it is in and update the + * parameters and arguments of that function; + * - then we finalize all the grouped functions by updating their bindings using BlockBuilder; + * - lastly, we go through the bindings again and substitute the bindings in a group with a single + * call to the corresponding grouped function. + * + * After transforming a Relax function, we update the function in the IRModule. Besides, we add all + * newly created grouped function to the IRModule. + */ +class OperatorFusor : public ExprMutator { + public: + using Group = GraphPartitioner::Group; + using GroupMap = std::unordered_map; + + OperatorFusor(IRModule mod, const GroupMap& obj2group, bool lift_constants = true) + : ExprMutator(mod), + mod_(std::move(mod)), + obj2group_(obj2group), + lift_constants_(lift_constants) {} + + /*! + * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition + * result on that graph, the constructor creates a mapping from each leaf AST object + * (e.g. parameters, variables, constants) to the group of the node corresponding to the object + * in the graph. + * \param mod The IRModule to be transformed + * \param graph The indexed-forward graph of the input IRModule + * \param groups The grouped result of the group partition on the input indexed-forward graph. + * \param lift_constant Whether or not to lift bound constants to parameters of the grouped + * function. + */ + OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector& groups, + bool lift_constant = true) + : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {} + + /*! + * \brief The main transformation on the IRModule + * \return The new IRModule after transformation + */ + IRModule Transform() { + for (const auto& [gv, func] : mod_->functions) { + // Only visit Relax function without attr kPrimitive. + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + auto updated_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(gv, updated_func); + } + } + return builder_->GetContextIRModule(); + } + + private: + static GroupMap CreateGroupMap(const IndexedForwardGraph& graph, + const std::vector& groups) { + GroupMap obj2group; + for (int nid = 0; nid < static_cast(graph.post_dfs_order.size()); ++nid) { + Group* group_root = groups[nid]->FindRoot(); + ICHECK(group_root != nullptr); + ICHECK(graph.post_dfs_order[nid]->ref != nullptr); + obj2group[graph.post_dfs_order[nid]->ref] = group_root; + } + return obj2group; + } + + bool IsTupleOutput(Function f) { + auto sinfo = GetStructInfo(f).as(); + ICHECK(sinfo); + return sinfo->ret->IsInstance(); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + return VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + return block; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + group2func_.clear(); + + // Step 1. Collect the bindings for each grouped function. + CollectFuncBindings(block->bindings); + + // Step 2. Collect all group's boundary (i.e. the output vars for each group) + CollectFuncBoundary(block->bindings); + + // Step 3. Create the grouped function for each group. + for (auto& [g, creator] : group2func_) { + creator.CreateFunction(g->attrs); + } + + // Step 4. 
Start generating the new binding block. + // - For groups with single binding, we directly recurse into the binding and emit the new one. + // - For groups with multiple bindings, we emit the call to the grouped function only when + // visiting the last binding of the group, because only by doing this we don't break the + // dependencies among the bindings of different groups. And therefore, we will skip all but the + // last binding of the group. + builder_->BeginDataflowBlock(); + + // For each group, record which variables need to be remapped to the output of TupleGetItem. + // Only relevant when the output of the grouped function is a tuple. + std::unordered_map> pending_tuple_get; + + // A grouped function which returns a tuple requires attaching TupleGetItem to each element and + // remapping variables in earlier bindings appropriately. Thus, a binding whose value depends on + // some elements of a tuple from other group's function must be emitted after a call to the + // tuple-producing function is emitted and remapping is done. + // To guarantee this, we process bindings in the order of the topological sort of the group + // dependency relations. + for (const auto& binding : TopoSortByGroupDep(block->bindings)) { + // Case 1. If the binding is the only binding in its group, recurse into it and emit the + // transformed binding as usual. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + VisitBinding(binding); + continue; + } + + const auto& it_creator = group2func_.find(group); + ICHECK(it_creator != group2func_.end()); + const FunctionCreator& func_info = it_creator->second; + + if (!func_info.function_.defined()) { + // The function is not created yet, so we skip the binding. + continue; + } + const Function& func = func_info.function_.value(); + + // If this binding belongs to a group whose output is a tuple, the original bound variable + // needs to be remapped to the output of TupleGetItem after the corresponding tuple is + // emitted. + if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) { + pending_tuple_get[group].push_back(binding->var); + } + + // Case 2. If the binding is not the last binding of the group, we skip it. + if (!func_info.bindings_.back().same_as(binding)) { + continue; + } + + // Case 3. The binding is the last binding of the group. + const auto* var_binding = binding.as(); + ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 " + "is supposed to be a variable binding"; + + // Step a. Add the grouped function to the IRModule + GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_); + + // Step b. Create the call to the deduplicated function, and then emit the call. + // - If this binding is an output binding, emit an output variable. + // - Otherwise, emit a dataflow variable. + Var new_var; + Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_)); + + if (var_binding->var->IsInstance()) { + new_var = builder_->Emit(call_to_emit); + } else { + new_var = builder_->EmitOutput(call_to_emit); + } + + // Step c. Update the mapping used for the remapping of the binding variables. + if (IsTupleOutput(func)) { + // If the output is a tuple, attach TupleGetItem to all tuple elements, and + // remap variables approriately. + // The variables that need to be remapped and the corresponding tuple indices are + // available in pending_tuple_get and tuple_get_indices_ respectively. 
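+        // (e.g. if the grouped function returns a 2-tuple and a variable was
+        // recorded with tuple index 1, that variable is remapped below to
+        // TupleGetItem(new_var, 1).)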
+ for (const auto& var : pending_tuple_get[group]) { + auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]); + var_remap_[var->vid] = builder_->Emit(tuple_get); + } + } else { + var_remap_[var_binding->var->vid] = new_var; + } + } + // Step 5. Finish the binding block generation. + return builder_->EndBlock(); + } + + /*! + * \brief Collect the bindings for each grouped function and update the information of the grouped + * function + * \param bindings The bindings to be collected + * \note The function update is done by `AppendBinding(...)` + */ + void CollectFuncBindings(const Array& bindings) { + for (const Binding& binding : bindings) { + // If the binding is the only binding in its group, there is no need to create a new function. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + continue; + } + // Add the binding to the grouped function it's in, and update the function information + // accordingly. + if (!group2func_.count(group)) { + group2func_.emplace(group, lift_constants_); + } + group2func_.find(group)->second.AppendBinding(binding); + } + } + + void CollectFuncBoundary(const Array& bindings) { + for (const Binding& binding : bindings) { + // Step 1. Get current binding's group + Group* cur_group = GetGroupFromBinding(binding); + + // Step 2. Collect all used vars in the binding value and update bondary. + // - If the var's group is same as the binding's, the var is defined in the same group + // - If the var's group is different with the binding's, the var must be the output from + // another group. Mark it to be the group output. + auto update_boundary = [this, binding, &cur_group](const Expr& e) { + if (e->IsInstance()) { + const Var& used_var = Downcast(e); + Group* producer_group = GetGroupFromVar(used_var); + // Only check those group defined before. + // Skip the vars from input or groups with single binding. + if (producer_group != cur_group) { + for (Group* depgroup : group_deps_[producer_group]) { + ICHECK(depgroup != cur_group) + << "A cyclic dependency detected between the groups " << binding->var->name_hint() + << " and " << used_var->name_hint() << " are in."; + } + group_deps_[cur_group].push_back(producer_group); + } + + if (auto producer = group2func_.find(producer_group); + producer_group != cur_group && producer != group2func_.end()) { + auto output_index = producer->second.AppendOutput(used_var); + tuple_get_indices_[used_var.get()] = output_index; + } + } + }; + + if (const auto* var_binding = binding.as()) { + PostOrderVisit(var_binding->value, update_boundary); + } else { + const auto* match_cast = binding.as(); + ICHECK_NOTNULL(match_cast); + PostOrderVisit(match_cast->value, update_boundary); + } + } + } + + /*! + * \brief Get the group which the input binding is in + * \param binding The binding to be queried + * \return The pointer to the group which the input binding is in + */ + Group* GetGroupFromBinding(const Binding& binding) { + Var var = binding->var; + return GetGroupFromVar(var); + } + + /*! + * \brief Get the group which the input var is in + * \param Var The var to be queried + * \return The pointer to the group which the input var is in + */ + Group* GetGroupFromVar(const Var& var) { + const auto& it_group = obj2group_.find(var.get()); + ICHECK(it_group != obj2group_.end()); + Group* group = it_group->second; + return group->FindRoot(); + } + + /*! 
+ * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by + * recursing into each argument + * \param args The arguments to be updated + * \return The updated arguments + */ + Array UpdateArgs(const Array& args) { + Array new_args; + new_args.reserve(args.size()); + for (const Expr& arg : args) { + new_args.push_back(VisitExpr(arg)); + } + return new_args; + } + + private: + // Topologically sort bindings according to the group dependency relations. + Array TopoSortByGroupDep(const Array& bindings) { + std::unordered_map> bindings_per_group; + // The order to visit groups should respect the original order of bindings as much as possible. + std::vector group_order; + for (const auto& binding : bindings) { + auto g = GetGroupFromBinding(binding); + group_order.push_back(g); // Duplication does not matter since each group is visited once. + bindings_per_group[g].push_back(binding); + } + + std::unordered_set visited; + + std::function)> dfs_visit; + dfs_visit = [this, &visited, &dfs_visit](Group* g, auto leaf_fun) { + if (!visited.count(g)) { + visited.insert(g); + for (auto dep : group_deps_[g]) { + dfs_visit(dep, leaf_fun); + } + leaf_fun(g); + } + }; + + Array sorted; + + for (auto g : group_order) { + dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) { + for (const auto& binding : bindings_per_group[leaf]) { + sorted.push_back(binding); + } + }); + } + + return sorted; + } + + /*! \brief The IRModule. */ + IRModule mod_; + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + GroupMap obj2group_; + /*! \brief Internal function information map. */ + std::unordered_map group2func_; + /*! \brief Record the index for TupleGetItem if the variable needs to be remapped to an output + * tuple element after fusion. */ + std::unordered_map tuple_get_indices_; + /*! + * \brief A map from a group to its dependent groups, used to detect cyclic dependencies. + * \note Use vector so we can be deterministic, there won't be a lot of dep groups so + * linear search is OK. + */ + std::unordered_map> group_deps_; + /*! \brief Whether or not to lift bound constants to parameters of the grouped function. */ + bool lift_constants_{true}; +}; + +IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { + support::Arena arena; + + // Step 1. Create the indexed-forward graph according to the input IRModule. + IndexedForwardGraph graph = GraphCreator::Create(mod, &arena); + + // Step 2. Partition the graph by applying the fusion algorithm. + std::vector groups = + GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph); + + // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition + // results. + return OperatorFusor(mod, graph, groups, /*lift_constants*/ true).Transform(); +} + +IRModule MakeGroupedFunctions( + IRModule mod, const std::unordered_map& partition, + bool lift_constants) { + return OperatorFusor(mod, partition, lift_constants).Transform(); +} + +/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group, + * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations + * in a group and create a grouped function. 
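+ *
+ * Illustrative example (the pattern name "dnnl.conv2d_relu" is hypothetical): for
+ *   pat = is_op("relax.nn.relu")(is_op("relax.nn.conv2d")(wildcard(), wildcard()))
+ * each matched conv2d + relu chain in a dataflow block, together with its bound variables,
+ * is placed into one group whose attrs carry kComposite = "dnnl.conv2d_relu"; OperatorFusor
+ * then outlines every such group into a composite function.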
+ */ +class PatternBasedPartitioner : ExprVisitor { + public: + using Group = GraphPartitioner::Group; + using GroupMap = OperatorFusor::GroupMap; + using PatternCheckContext = transform::PatternCheckContext; + using ExprVisitor::VisitExpr_; + using FCheckMatch = runtime::TypedPackedFunc; + + static GroupMap Run(String pattern_name, DFPattern pattern, + Map annotation_patterns, FCheckMatch check, Expr expr, + support::Arena* arena) { + PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena); + part.VisitExpr(expr); + return part.group_map_; + } + + PatternBasedPartitioner(String pattern_name, DFPattern pattern, + Map annotation_patterns, FCheckMatch check, + support::Arena* arena) + : pat_name_(pattern_name), + pat_(pattern), + annotation_pat_(annotation_patterns), + check_(check), + arena_(arena) {} + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + current_block_use_def_ = DataflowBlockUseDef(GetRef(block)); + ExprVisitor::VisitBindingBlock_(block); + current_block_use_def_ = {}; + } + + void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make(); } + + void VisitBinding_(const VarBindingNode* binding) final { + bindings_.Set(binding->var, binding->value); + value_to_bound_var_.Set(binding->value, binding->var); + ExprVisitor::VisitBinding_(binding); + } + + void VisitExpr_(const ConstantNode* op) final { group_map_[op] = arena_->make(); } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + VisitVarDef(binding->var); + if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + if (check_ != nullptr && !check_(CreatePatternCheckContext(call, matches_opt.value()))) { + return; + } + // If a match is found, put all matching expressions into the same group. + // OperatorFusor also requires that the bound variable be in the same group as the RHS value. + // Since is_op(...) based pattern only matches against call nodes on the right hand side, + // we need to take care of groups corresponding to the LHS bound variables carefully. + + // In the example below, conv2d + relu pattern would match if the "call" variable in this + // function points to the relu op. We identify the group corresponding to "conv1", and make + // it the representative group for relu and conv2d on the RHS and also "lv" on the LHS. + + // with R.dataflow(): + // lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(...) + // conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + + // parent_group corresponds to the group of "conv1" above. + auto parent_group = GetGroupForBoundVar(binding->var); + ICHECK(parent_group); + parent_group->attrs.Set(attr::kComposite, pat_name_); + for (const auto& [pat, match] : matches_opt.value()) { + // Put all matching call nodes into the parent group. + if (pat->IsInstance() && match != GetRef(call)) { + // Put the bound variable on the LHS into the same parent group. 
+ AddToGroup(value_to_bound_var_[match], parent_group); + } + } + } + } + + private: + void AddToGroup(Expr e, Group* to) { + if (group_map_[e.get()] != to) { + --group_map_[e.get()]->num_nodes; + group_map_[e.get()]->parent = to; + ++to->num_nodes; + } + } + + Group* GetGroupForBoundVar(const Var& bound_var) { + ICHECK(group_map_.count(bound_var.get())); + return group_map_[bound_var.get()]->FindRoot(); + } + + PatternCheckContext CreatePatternCheckContext(const CallNode* call, + const Map& matched_result) { + Map annotated_expr; + for (const auto& it : annotation_pat_) { + if (matched_result.count(it.second)) { + annotated_expr.Set(it.first, matched_result[it.second]); + } + } + + Map matched_bindings; + for (const auto& [pat, match] : matched_result) { + if (pat->IsInstance()) { + matched_bindings.Set(value_to_bound_var_[match], match); + } + } + + return PatternCheckContext(GetRef(call), annotated_expr, matched_bindings, + current_block_use_def_, value_to_bound_var_); + } + + String pat_name_; + DFPattern pat_; + Map annotation_pat_; + FCheckMatch check_; + support::Arena* arena_; + Map bindings_; + Map value_to_bound_var_; + Map> current_block_use_def_; + GroupMap group_map_; +}; + +/*! + * \brief Wrap each created composite function with another function, whose body consists + * only of a call to the composite function, and annotate the outer function with kCodegen + * and kGlobalSymbol attributes. + */ +class CompositeFunctionAnnotator : public ExprMutator { + public: + explicit CompositeFunctionAnnotator(IRModule mod) : ExprMutator(mod) {} + using ExprMutator::VisitExpr_; + + IRModule Run() { + auto mod = builder_->GetContextIRModule(); + auto gvar = mod->GetGlobalVar("main"); + auto func = Downcast(mod->Lookup(gvar)); + auto new_func = + Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(gvar, new_func); + return builder_->GetContextIRModule(); + } + + Expr VisitExpr_(const CallNode* call_node) final { + if (auto const* gvar = call_node->op.as()) { + if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) { + return Call(it->second, call_node->args); + } + auto func = builder_->GetContextIRModule()->Lookup(GetRef(gvar)); + if (auto composite_name = func->GetAttr(attr::kComposite)) { + auto new_func = Downcast(VisitExpr(func)); + auto codegen_name = GetCodegenName(composite_name.value()); + auto gsymbol = gvar->name_hint + "_" + codegen_name; + new_func = WithAttrs(new_func, + {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}}); + builder_->GetContextIRModule()->Remove(GetRef(gvar)); + auto new_gvar = builder_->AddFunction(new_func, gsymbol); + gvar_map_[gvar] = new_gvar; + return Call(new_gvar, call_node->args); + } + } + return ExprMutator::VisitExpr_(call_node); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto f_inner = ExprMutator::VisitExpr_(func_node); + auto composite_name = func_node->GetAttr(attr::kComposite); + ICHECK(composite_name); + + Array param_vars; + Array params; + + for (auto v : func_node->params) { + Var new_v(v->name_hint(), GetStructInfo(v)); + param_vars.push_back(new_v); + params.push_back(new_v); + } + + return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info); + } + + private: + String GetCodegenName(const std::string& composite_name) { + auto delim_pos = composite_name.find("."); + ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " + "start with a compiler name followed by period."; + 
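+    // e.g. a (hypothetical) pattern name "dnnl.conv2d_relu" yields the codegen name "dnnl",
+    // which is attached to the wrapper function via the kCodegen attribute above.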
return composite_name.substr(0, delim_pos); + } + + /*! \brief A map from old global vars to their replacements. */ + std::unordered_map gvar_map_; +}; + +IRModule FuseOpsByPattern(const tvm::Array& patterns, IRModule mod, + bool bind_constants, bool annotate_codegen) { + support::Arena arena; + for (const auto& pattern : patterns) { + OperatorFusor::GroupMap group_map; + for (const auto& entry : mod->functions) { + if (entry.second->IsInstance()) { + continue; + } + auto map = PatternBasedPartitioner::Run( + pattern->name, pattern->pattern, pattern->annotation_patterns, + pattern->check.value_or(nullptr), entry.second, &arena); + group_map.insert(map.begin(), map.end()); + } + mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants); + } + if (annotate_codegen) { + return CompositeFunctionAnnotator(mod).Run(); + } + return mod; +} + +namespace transform { + +FusionPattern::FusionPattern(String name, DFPattern pattern, + Map annotation_patterns, + Optional check) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->pattern = std::move(pattern); + n->annotation_patterns = std::move(annotation_patterns); + n->check = check; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(FusionPatternNode); +TVM_REGISTER_GLOBAL("relax.transform.FusionPattern") + .set_body_typed([](String name, DFPattern pattern, Map annotation_patterns, + Optional check) { + return FusionPattern(name, pattern, annotation_patterns, check); + }); + +PatternCheckContext::PatternCheckContext(Expr matched_expr, Map annotated_expr, + Map matched_bindings, + Map> var_usages, + Map value_to_bound_var) { + ObjectPtr n = make_object(); + n->matched_expr = std::move(matched_expr); + n->annotated_expr = std::move(annotated_expr); + n->matched_bindings = std::move(matched_bindings); + n->var_usages = std::move(var_usages); + n->value_to_bound_var = std::move(value_to_bound_var); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PatternCheckContextNode); + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps)); + return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue()); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOps", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); + +Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, + bool annotate_codegen) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc new file mode 100644 index 000000000000..b695c5f6c7cf --- /dev/null +++ b/src/relax/transform/fuse_tir.cc @@ -0,0 +1,692 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" +#include "../../tir/ir/functor_common.h" + +namespace tvm { +namespace tir { + +// TODO(Siyuan): move it to somewhere under tir folder +/*! + * \brief Substitute a given source buffer with a given target buffer in statements or expressions. + */ +class FuseTIRBufferSubstitor : private StmtExprMutator { + public: + static Stmt Substitute(const Map& buffer_map, Stmt stmt) { + return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt)); + } + + private: + explicit FuseTIRBufferSubstitor(const Map& buffer_map) { + for (const auto& kv : buffer_map) { + const Buffer& src = kv.first; + const Buffer& tgt = kv.second; + buffer_var_map_[src->data.get()] = tgt; + } + } + + PrimExpr VisitExpr_(const VarNode* _op) final { + auto it = buffer_var_map_.find(_op); + if (it != buffer_var_map_.end()) { + return it->second->data; + } else { + return GetRef(_op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer = it->second; + return BufferLoad(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer = it->second; + return BufferStore(n); + } else { + return std::move(store); + } + } + + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtMutator::VisitStmt_(_op)); + + // Define the mutation functions. + auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { + const Buffer& src_buffer = match_buffer->source->buffer; + auto it = buffer_var_map_.find(src_buffer->data.get()); + if (it != buffer_var_map_.end()) { + return MatchBufferRegion(match_buffer->buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + }; + + auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + return it == buffer_var_map_.end() ? buffer_region + : BufferRegion(it->second, buffer_region->region); + }; + + // Step 1. Mutate `match_buffers`. + Array match_buffers = + MutateArray(block->match_buffers, f_mutate_match_buffers); + // Step 2. Mutate the read/write region. 
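+    // Note (for illustration): when an upstream function's output buffer and a downstream
+    // function's input buffer were both mapped to one intermediate buffer in buffer_var_map_,
+    // the regions rewritten below end up referring to that single shared buffer.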
+ Array reads = MutateArray(block->reads, f_mutate_read_write_region); + Array writes = MutateArray(block->writes, f_mutate_read_write_region); + + reads = UnionAccessRegion(reads); + writes = UnionAccessRegion(writes); + + if (reads.same_as(block->reads) && // + writes.same_as(block->writes) && // + match_buffers.same_as(block->match_buffers)) { + return std::move(block); + } else { + auto n = CopyOnWrite(block.get()); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->match_buffers = std::move(match_buffers); + return Block(n); + } + } + + private: + /*! \brief Mapping from src buffer.data to tgt buffer. */ + std::unordered_map buffer_var_map_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; + + Array UnionAccessRegion(const Array& regions) const { + // For now we only allow Buffer access the same elements. + // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but need to union to `A[vi, vj]` + // However, `A[vi, vj], A[vi, vj + 1]` is not allow for now. + // Note: the order of return region should remain the same as the first occurance of the region + Array ret; + std::unordered_map buffer_region_set; + + for (const BufferRegion& region : regions) { + auto it = buffer_region_set.find(region->buffer.get()); + if (it == buffer_region_set.end()) { + ret.push_back(region); + buffer_region_set[region->buffer.get()] = region->region; + } else { + ICHECK(structural_equal_(region->region, it->second)); + } + } + + if (ret.size() == regions.size()) { + return regions; + } else { + return ret; + } + } +}; + +/*! \brief A mutator which detect block name duplication and deduplicate the names. */ +class BlockNameDeduplicator : public tir::StmtMutator { + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(tir::StmtMutator::VisitStmt_(op)); + + String name = GetUniqueName(block->name_hint); + + if (name == block->name_hint) { + return std::move(block); + } else { + ObjectPtr n = CopyOnWrite(block.get()); + n->name_hint = std::move(name); + return Stmt(n); + } + } + + String GetUniqueName(const String& prefix) { + String unique_prefix = prefix; + auto it = name_count_.find(prefix); + while (name_count_.count(unique_prefix)) { + unique_prefix = prefix + "_" + std::to_string(++it->second); + } + name_count_[unique_prefix] = 0; + return unique_prefix; + } + + // TODO(relax-team): It should detects the number suffix and do renaming properly + // e.g. GetUniqueName("name1") should return "name2" instead of "name10". + /*! \brief The count map to make block name unique. */ + std::unordered_map name_count_; +}; + +} // namespace tir + +namespace relax { + +class FusedTIRConstructor : public ExprVisitor { + public: + /*! 
+ * \brief Construct a fused TIR PrimFunc from a relax sub-function + * \param mod The IRModule + * \param gv The global var of relax subfunction to be fused into one PrimFunc + * \return The fused TIR PrimFunc + */ + static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) { + FusedTIRConstructor visitor(mod, gv->name_hint); + BaseFunc f = mod->Lookup(gv); + CHECK(f->IsInstance()) + << "Expected relax functions, but got: " << f->GetTypeKey(); + CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) + << "Expected a function with attr `kPrimitive`"; + visitor(Downcast(f)); + return visitor.fused_tir_; + } + + private: + explicit FusedTIRConstructor(const IRModule& mod, const String& func_name) + : mod_(mod), func_name_(func_name) {} + + void VisitExpr_(const FunctionNode* func) final { + // Step 1. Create buffers for function params + for (const Var& relax_param : func->params) { + auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), // + relax_param->name_hint()); + const Array& params = ret.first; + const Array& buffers = ret.second; + ICHECK_EQ(params.size(), buffers.size()); + for (size_t i = 0; i < params.size(); ++i) { + func_info_.buffer_map.Set(params[i], buffers[i]); + func_info_.params.push_back(params[i]); + } + func_info_.expr2buffers.Set(relax_param, buffers); + } + + // Step 2. Visit Function body and create intermediate buffers + ExprVisitor::VisitExpr_(func); + + // Step 3. Create and remap buffers for function output + ICHECK(func->body->IsInstance()) + << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey(); + Expr body = Downcast(func->body)->body; + auto it = func_info_.expr2buffers.find(body); + ICHECK(it != func_info_.expr2buffers.end()) + << "Fail to detect output buffers for function body"; + const Array& buffers = (*it).second; + for (size_t i = 0; i < buffers.size(); ++i) { + tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle())); + func_info_.buffer_map.Set(param, buffers[i]); + func_info_.params.push_back(param); + func_info_.output_buffers.insert(buffers[i].get()); + } + + // Step 4. Create PrimFunc + fused_tir_ = ConstructFunc(); + } + + void VisitBinding_(const VarBindingNode* binding) final { + // Update expr2buffers by visiting values. + this->VisitExpr(binding->value); + auto it = func_info_.expr2buffers.find(binding->value); + if (it != func_info_.expr2buffers.end()) { + // assign binding var to the buffers of the value + func_info_.expr2buffers.Set(binding->var, (*it).second); + } else { + LOG(FATAL) << "Unsupported binding value: " << binding->value; + } + } + + void VisitBinding_(const MatchCastNode* match_cast) final { + LOG(FATAL) << "MatchCast is unsupported in primitive functions"; + } + + void VisitExpr_(const CallNode* call) final { + ExprVisitor::VisitExpr_(call); + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op == call_tir_op_) + << "Only call_tir is supported in primitive function, but got: " << GetRef(call); + + // Step 1. Get Global var and PrimFunc + GlobalVar gv = Downcast(call->args[0]); + tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + + // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication + tir::PrimFunc prim_func = tir::RenewDefs(prim_func_); + + // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block + // TODO(Siyuan): support un-schedulable functions. 
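+    // "Schedulable" here means the PrimFunc body is a single BlockRealize of its root block,
+    // e.g. (sketch, for illustration only) a function whose body is
+    //   with T.block("root"): ... alloc_buffers and compute blocks ...
+    // so that the root block's alloc_buffers and body can be spliced into the fused function.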
+ ICHECK(prim_func->body->IsInstance()) + << "Only schedulable functions (whose body is the root block) can be fused"; + const tir::BlockRealize& root_realize = Downcast(prim_func->body); + const tir::Block& root_block = root_realize->block; + + // Step 4. Add all the original alloc_buffers and body to the fused function. + func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(), + root_block->alloc_buffers.begin(), + root_block->alloc_buffers.end()); + func_info_.bodies.push_back(root_block->body); + + // Step 5. Map input arguments to buffer + MapInputBuffer(prim_func, call->args[1]); + size_t num_output_buffers = GetCallTIROutputSize(call); + AllocateIntermediateBuffer(GetRef(call), prim_func, num_output_buffers); + // Update fused func name + func_info_.global_name += "_" + gv->name_hint; + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item) final { + ExprVisitor::VisitExpr_(tuple_get_item); + auto it = func_info_.expr2buffers.find(tuple_get_item->tuple); + if (it != func_info_.expr2buffers.end()) { + int begin_buf_idx = 0; + int end_buf_idx = 0; + const TupleType& tuple_type = Downcast(tuple_get_item->tuple->checked_type()); + for (int i = 0; i < tuple_get_item->index; ++i) { + begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]); + } + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]); + func_info_.expr2buffers.Set( + GetRef(tuple_get_item), + {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); + } + } + + void VisitExpr_(const TupleNode* tuple) final { + ExprVisitor::VisitExpr_(tuple); + Array buffers; + for (const Expr& expr : tuple->fields) { + auto it = func_info_.expr2buffers.find(expr); + if (it != func_info_.expr2buffers.end()) { + buffers.insert(buffers.end(), (*it).second.begin(), (*it).second.end()); + } + } + if (!buffers.empty()) { + func_info_.expr2buffers.Set(GetRef(tuple), buffers); + } + } + + void VisitExpr_(const ConstantNode* op) final { + LOG(FATAL) << "Relax.Constant is not supported in primitive functions."; + } + + /*! + * \brief Get the number of outputs for a call_tir node. + * \return The number of outputs. + */ + static size_t GetCallTIROutputSize(const CallNode* call) { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op.same_as(call_tir_op_)); + ICHECK_EQ(call->sinfo_args.size(), 1); + if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { + return tuple_sinfo->fields.size(); + } else { + return 1; + } + } + + /*! \brief Map old TIR func param buffer to new buffer, and then update `buffer_subst_map` */ + void MapArgsToBuffer(const Array args, const Array& buffers) { + size_t buffer_idx = 0; + for (const Expr& arg : args) { + if (const auto* v = arg.as()) { + auto it = func_info_.expr2buffers.find(GetRef(v)); + // Substitute the buffer with the already allocated one if it is an intermediate var + if (it != func_info_.expr2buffers.end()) { + for (const tir::Buffer& target_buffer : (*it).second) { + ICHECK_LT(buffer_idx, buffers.size()); + const tir::Buffer& buffer = buffers[buffer_idx]; + // TODO(relax-team): Add support for symbolic shape fusion + for (const PrimExpr& shape_expr : buffer->shape) { + ICHECK(shape_expr.as()) << "Only support constant shape fusion for now"; + } + func_info_.buffer_subst_map.Set(buffer, target_buffer); + buffer_idx++; + } + } + } + } + // Make sure every buffers are maped. + ICHECK_EQ(buffer_idx, buffers.size()); + } + + /*! 
+ * \brief Update buffer mapping `func_info_.buffer_subst_map` for input args + * \param func The old TIR PrimFunc + * \param output_size The number of output params. All output params are at the end of param list. + */ + void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { + Array arg_list; + Array buffer_list; + if (const auto* arg_tuple = args.as()) { + arg_list = arg_tuple->fields; + } else { + arg_list = {args}; + } + + ICHECK_GE(func->params.size(), arg_list.size()); + for (size_t i = 0; i < arg_list.size(); ++i) { + const tir::Var& param = func->params[i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + buffer_list.push_back(buffer); + } + + MapArgsToBuffer(arg_list, buffer_list); + } + + /*! + * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the PrimFunc output(s) are + * intermediate results. + * \param expr The relax Expr, which can be binding vars or binding values. + * \param func The old TIR PrimFunc + * \param output_size The number of output params. All output params are at the end of param list. + */ + void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) { + size_t n = func->params.size(); + ICHECK_GE(n, output_size); + // Allocate intermediate buffer + Array alloc_buffers; + for (size_t i = 0; i < output_size; ++i) { + const tir::Var& param = func->params[n - output_size + i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + func_info_.alloc_buffers.push_back(buffer); + alloc_buffers.push_back(buffer); + } + // Update expr2buffers + func_info_.expr2buffers.Set(expr, alloc_buffers); + } + + /*! + * \brief Create an TIR func params and buffers with specified relax type and shape + * \param struct_info The struct info + * \param name_hint The name hint for params and buffers + * \param index The index used for unique name_hint if type is Tuple. + * -1 means no need to add postfix since the relax param is not a Tuple. + * \return The created TIR func params and buffers + */ + static std::pair, Array> CreateParamsAndBuffers( + StructInfo struct_info, const String& name_hint, int index = -1) { + Array params; + Array buffers; + if (const auto* tensor = struct_info.as()) { + // Case 1. the relax param is a DynTensor, we directly create a tir var and buffer + const auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape."; + + String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index); + DataType dtype = tensor->dtype; + tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name); + // Differentiate buffer name and param name by adding prefix `v_` to param + // Every symbol should be unique in TVMScript, and Buffer is used more than param + // So we decide to make sure buffer names have better readability. + tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle())); + params.push_back(std::move(param)); + buffers.push_back(std::move(buffer)); + } else if (const auto* tuple = struct_info.as()) { + // Case 2. the relax param is a Tuple, we recursively visit each field until it's a DynTensor + // Enable postfix + if (index == -1) index = 0; + for (size_t i = 0; i < tuple->fields.size(); ++i) { + auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index); + const Array& ret_params = ret.first; + const Array& ret_buffers = ret.second; + ICHECK_EQ(ret_params.size(), ret_buffers.size()); + // Adding tuple field results to the end of params and buffers. 
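+        // Naming sketch (illustrative): a tuple-typed relax param "x" with two tensor fields
+        // yields TIR buffers "x_0"/"x_1" and handle params "p_x_0"/"p_x_1", since the running
+        // field index is appended to name_hint and the "p_" prefix marks the param var.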
+ params.insert(params.end(), ret_params.begin(), ret_params.end()); + buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end()); + index += ret_params.size(); + } + } else { + ICHECK(false) << "shapes are expected to be ShapeExprNode or TupleNode"; + } + return std::make_pair(params, buffers); + } + + /*! + * \brief Construct fused TIR func with collected FuseFuncInfo + * \return The fused TIR + */ + tir::PrimFunc ConstructFunc() { + Map attr_map; + attr_map.Set("tir.noalias", tir::const_true()); + ICHECK(func_info_.global_name != "fused"); + // Remove output buffers from func_info_.alloc_buffers + Array alloc_buffers; + for (const tir::Buffer& buf : func_info_.alloc_buffers) { + if (func_info_.output_buffers.count(buf.get()) == 0) { + alloc_buffers.push_back(buf); + } + } + tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); + body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body); + body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers); + body = tir::BlockRealize({}, Bool(true), Downcast(body)); + tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, + DictAttrs(attr_map)); + return func; + } + + /*! \brief Get DynTensor numbers from recursive Tuples. */ + static size_t GetTotalTensorSize(const Type& type) { + if (type.as()) { + return 1; + } else if (const auto* tuple_type = type.as()) { + size_t num = 0; + for (const Type& type : tuple_type->fields) { + num += GetTotalTensorSize(type); + } + return num; + } else { + LOG(FATAL) << "DynTensorType and TupleType are expect, but got: " << type; + return 0; + } + } + + /********** Function Info **********/ + + /*! \brief auxiliary information for FuseTIR */ + struct FuseFuncInfo { + /*! \brief The arguments for calling prim_func */ + Array arguments; + /*! + * \brief The map from each dataflow var (intermediate var) to the corresponding buffers + * allocated in the fused func + */ + Map> expr2buffers; + /*! \brief The buffers to allocate in the fused func*/ + Array alloc_buffers; + /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ + Array bodies; + /*! \brief The params of the fused function*/ + Array params; + /*! + * \brief The map from buffer in original functions to corresponding buffer in the fused + * function + */ + Map buffer_subst_map; + /*! \brief The `buffer_map` in the fused function*/ + Map buffer_map; + /*! \brief The output buffers in the function buffer_map*/ + std::unordered_set output_buffers; + /*! \brief The name of the fused function */ + std::string global_name = "fused"; + }; + + /*! \brief The IRModule */ + const IRModule& mod_; + /*! \brief The name hint for the input func. */ + String func_name_; + /*! \brief The helper info to fuse TIR prim_func */ + FuseFuncInfo func_info_; + /*! \brief The tir function after fusion*/ + tir::PrimFunc fused_tir_; +}; + +/*! + * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR. + */ +class TIRFuseMutator : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder. + TIRFuseMutator mutator(mod); + // Step 1. 
Fuse all primitive relax functions, store the result in `fused_tir_funcs_` + for (const auto& kv : mod->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + // Only fuse primitive relax functions + if (func->IsInstance() && func->HasNonzeroAttr(attr::kPrimitive)) { + tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv); + mutator.fused_tir_funcs_.Set(gv, fused_tir); + } + } + + // Step 2. Update all non-primitive relax functions and add it, with the dependent function, + // into the new IRModule + for (const auto& kv : mod->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + relax::Function update_func = Downcast(mutator.VisitExpr(func)); + mutator.builder_->AddFunction(update_func, gv->name_hint); + } + } + + // Step 3. Copy over module attributes and return. + auto modified_mod = mutator.builder_->GetContextIRModule(); + if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict); + return modified_mod; + } + + private: + explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {} + + using ExprMutator::VisitExpr_; + + // Get shape from call tir + static Expr GetCallTIRShape(StructInfo sinfo) { + if (auto* tuple = sinfo.as()) { + Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + return Tuple(fields); + } else { + auto* tensor = sinfo.as(); + ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; + auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; + return GetRef(shape_expr); + } + } + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); + + if (call->op->IsInstance()) { + // Case 1. It is a relax cross function call + GlobalVar old_gv = Downcast(call->op); + auto it = fused_tir_funcs_.find(old_gv); + if (it != fused_tir_funcs_.end()) { + const tir::PrimFunc& fused_tir = (*it).second; + // Case 1.1. It calls a primitive relax function, update the call into a call_tir + GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint); + // Step a. Flatten all args since call_tir does not support Tuple value. + Array arg_list; + for (const Expr& arg : call->args) { + Array flattened = FlattenArg(arg); + arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); + } + // Step b. Create call_tir + Array call_args = {fused_tir_gv, Tuple(arg_list)}; + return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)}); + } else { + // Case 1.2. The callee function is not primitive, nothing to do. + return call; + } + } else if (call->op == call_tir_op_) { + // Case 2. It is a call_tir, re-emit the PrimFunc. + if (const auto* gv = call->args[0].as()) { + tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(gv))); + GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); + return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span); + } + } + + // Case 3. CallNode in other types. Leave it as it is. + return call; + } + + /********** Helper Functions **********/ + + /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. 
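+   * e.g. (illustrative) an argument whose struct info is a 2-field tuple of tensors is
+   * expanded into two emitted TupleGetItem bindings; nested tuples are flattened recursively.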
*/ + Array FlattenArg(const Expr& arg) { + if (const auto* tuple_sinfo = GetStructInfoAs(arg)) { + Array arg_list; + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + Expr new_arg = builder_->Emit(TupleGetItem(arg, i)); + Array flattened = FlattenArg(new_arg); + arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); + } + return arg_list; + } else { + return {arg}; + } + } + + private: + /*! \brief The IRModule */ + const IRModule& mod_; + /*! \brief The map from global var of primitive relax function to generated prim func. */ + Map fused_tir_funcs_; +}; + +IRModule FuseTIR(IRModule mod) { + mod = TIRFuseMutator::Transform(mod); + return mod; +} + +namespace transform { + +Pass FuseTIR() { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseTIR", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc new file mode 100644 index 000000000000..330fe9a72ac4 --- /dev/null +++ b/src/relax/transform/infer_amp_utils.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "infer_amp_utils.h" + +namespace tvm { +namespace relax { + +NType NTypeFrom(const StructInfo& sinfo, DataType dtype) { + auto fmapleaf = [&](const StructInfo& sinfo) -> NType { + const auto* tensor = sinfo.as(); + ICHECK(tensor) << "Expected TensorStructInfo, but got " << sinfo; + if (dtype == DataType::Void()) + return NType(DLDataType2String(tensor->dtype)); + else + return NType(DLDataType2String(dtype)); + }; + return MapToNestedMsg(sinfo, fmapleaf); +} + +NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructInfo(expr), dtype); } + +NType NTypeMerge(const NType& a, const NType& b) { + auto fcombine = [&](const String& a_str, const String& b_str) -> String { + DataType a = DataType(String2DLDataType(a_str)); + DataType b = DataType(String2DLDataType(b_str)); + ICHECK_EQ(a.code(), b.code()); + ICHECK_EQ(a.lanes(), b.lanes()); + return a.bits() > b.bits() ? 
a_str : b_str; + }; + return CombineNestedMsg(a, b, fcombine); +} + +Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { + return {Integer(MixedPrecisionPolicyKind::kFollow), call}; +} + +Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { + return {Integer(MixedPrecisionPolicyKind::kNever), call}; +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h new file mode 100644 index 000000000000..3c98af6db965 --- /dev/null +++ b/src/relax/transform/infer_amp_utils.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file infer_amp_utils.h + * \brief Utility functions to be used in to_mixed_precision pass. + */ + +#ifndef TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_ +#define TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +using runtime::DLDataType2String; +using runtime::String; +using runtime::String2DLDataType; + +enum MixedPrecisionPolicyKind : int { kAlways = 0, kFollow = 1, kNever = 2 }; + +/*! \brief the operator pattern */ +using TMixedPrecisionPolicy = int; + +// NType is the message we want to track for vars with nested tensorstructinfo +// which represents the realization decision of the var. +// The string is the name of the dtype decision. 
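+// For illustration: a value of type R.Tuple(R.Tensor, R.Tensor) whose fields are realized as
+// float16 and float32 carries the nested message ("float16", "float32"); NTypeMerge keeps the
+// wider dtype per leaf when two messages disagree.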
+using NType = NestedMsg; + +struct NTypeEqual { + bool operator()(const NType& a, const NType& b) const { + auto dtype_equal = [](const String& a, const String& b) { return a == b; }; + return Equal(a, b, dtype_equal); + } +}; + +// Construct a NType from an StructInfo +NType NTypeFrom(const StructInfo& sinfo, DataType dtype = DataType::Void()); + +// Construct a NType from an Expr +NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void()); + +// Merge two messages, we keep the higher precision type for each leaf tensor +NType NTypeMerge(const NType& a, const NType& b); + +// The map that notes the NType message of each var +using VarDTypeMap = std::unordered_map; + +// Call is a call node, out_dtype is the expected output_dtype +using FInferMixedPrecision = + runtime::TypedPackedFunc; + +Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); + +Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_ diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc new file mode 100644 index 000000000000..d746f9394a75 --- /dev/null +++ b/src/relax/transform/infer_layout_utils.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "infer_layout_utils.h" + +#include "utils.h" + +namespace tvm { +namespace relax { + +using tir::IterVar; +using tir::Layout; + +Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) { + ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim()) + << "Layouts must have the same size"; + std::vector axes; + for (size_t i = 0; i < src.ndim(); ++i) { + axes.push_back(input->axes[src.IndexOf(dst[i])]); + } + return Layout(axes); +} + +String TransposeStrLike(const String& input, const Layout& src, const Layout& dst) { + ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim()) + << "Layouts must have the same size"; + std::string axes; + for (size_t i = 0; i < src.ndim(); ++i) { + axes.push_back(input.at(src.IndexOf(dst[i]))); + } + return axes; +} + +int FindAxis(const Layout& dst, int axis) { + axis = (axis + dst.ndim()) % dst.ndim(); + return dst.name().find('A' + axis); +} + +Layout InitialLayout(int ndim) { + ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; + return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim); +} + +LayoutDecision InitialLayoutDecision(int ndim) { + if (ndim == kUnknownNDim) { + return LayoutDecision::InitUnknownDim(); + } + ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; + return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim); +} + +NLayout InitialNLayout(const StructInfo& sinfo) { + auto fmapleaf = [&](const StructInfo& sinfo) -> NLayout { + if (const auto* tensor_sinfo = sinfo.as()) { + return NLayout(InitialLayoutDecision(tensor_sinfo->ndim)); + } + return LayoutDecision::InitUnknownDim(); + }; + return MapToNestedMsg(sinfo, fmapleaf); +} + +NLayout InitialNLayout(const Expr& expr) { return InitialNLayout(GetStructInfo(expr)); } + +LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& arg) { + NLayout nlayout = GetNLayout(var_layout_map, arg); + ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << arg; + return nlayout.LeafValue(); +} + +NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) { + auto fmapleaf = [&](const Expr& expr) -> NLayout { + if (const auto* var = expr.as()) { + auto it = var_layout_map.find(GetRef(var)); + if (it != var_layout_map.end()) { + return (*it).second; + } else { + return InitialNLayout(expr); + } + } else if (const auto* constant = expr.as()) { + return InitialLayoutDecision(constant->data.Shape().size()); + } + return LayoutDecision::InitUnknownDim(); + }; + return MapToNestedMsg(arg, fmapleaf); +} + +bool NoDesiredLayout(const Call& call, const Map>& desired_layouts) { + const OpNode* op_node = call->op.as(); + if (op_node == nullptr) return false; + const auto& it = desired_layouts.find(op_node->name); + return it == desired_layouts.end(); +} + +LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) { + int src_ndim = src->layout.ndim(); + // broadcast case + if (src_ndim == dst_ndim) { + return src; + } else { + ICHECK_LT(src_ndim, dst_ndim) << "Cannot broadcast from " << src_ndim << " to " << dst_ndim; + std::string layout = InitialLayout(dst_ndim - src_ndim).name(); + for (int i = 0; i < src_ndim; ++i) { + layout.push_back(src->layout.name()[i] + dst_ndim - src_ndim); + } + return LayoutDecision(Layout(layout)); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h new file mode 100644 index 000000000000..2cbbe23ede66 --- 
/dev/null +++ b/src/relax/transform/infer_layout_utils.h @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file infer_layout_utils.h + * \brief Utility functions to alter the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. + */ + +#ifndef TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_ +#define TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using tir::Layout; + +/*! + * \brief A layout decision node that holds the layout decision of the tensor. + * \param layout The layout of the tensor. + */ +class LayoutDecisionNode : public Object { + public: + /*! \brief The layout decision of the tensor. */ + Layout layout; + /*! \brief Whether the dim of tensor is unknown. */ + bool is_unknown_dim = false; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("layout", &layout); } + + TVM_DECLARE_BASE_OBJECT_INFO(LayoutDecisionNode, Object); + + static constexpr const char* _type_key = "relax.transform.LayoutDecision"; +}; + +class LayoutDecision : public ObjectRef { + public: + LayoutDecision(Layout layout, bool is_unknown_dim = false) { // NOLINT(*) + auto n = make_object(); + n->layout = std::move(layout); + n->is_unknown_dim = is_unknown_dim; + data_ = n; + } + + static LayoutDecision InitUnknownDim() { return LayoutDecision(Layout::Undef(), true); } + + inline std::string name() const { + if (operator->()->is_unknown_dim) { + return "unknown_dim"; + } + return operator->()->layout.name(); + } + + TVM_DEFINE_OBJECT_REF_METHODS(LayoutDecision, ObjectRef, LayoutDecisionNode); +}; + +using NLayout = NestedMsg; + +/*! + * \brief An output structure to hold results from FInferCorrectLayout calls. + * \param input_layouts Inferred input layouts. + * \param output_layouts Inferred output layouts. + * \param new_attrs Updated attributes consistent with inferred layouts. 
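+ *
+ * For illustration: a layout inference handler (FRelaxInferLayout, defined below) that switches
+ * a conv2d to an NHWC data layout would return the inferred layouts of its inputs and output
+ * together with attrs whose layout fields were rewritten to match, so the calling pass can
+ * insert the required layout transformations around the call.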
+ */ +class InferLayoutOutputNode : public Object { + public: + Array input_layouts; + Array output_layouts; + Attrs new_attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("input_layouts", &input_layouts); + v->Visit("output_layouts", &output_layouts); + v->Visit("new_attrs", &new_attrs); + } + + TVM_DECLARE_BASE_OBJECT_INFO(InferLayoutOutputNode, Object); + + static constexpr const char* _type_key = "relax.transform.InferLayoutOutput"; +}; + +class InferLayoutOutput : public ObjectRef { + public: + explicit InferLayoutOutput(Array input_layouts, Array output_layouts, + Attrs new_attrs) { + auto n = make_object(); + n->input_layouts = std::move(input_layouts); + n->output_layouts = std::move(output_layouts); + n->new_attrs = std::move(new_attrs); + data_ = n; + } + TVM_DEFINE_OBJECT_REF_METHODS(InferLayoutOutput, ObjectRef, InferLayoutOutputNode); +}; + +struct NLayoutEqual { + bool operator()(const NLayout& a, const NLayout& b) const { + auto layout_equal = [](const LayoutDecision& a, const LayoutDecision& b) { + if (a.defined() && b.defined()) { + return a.name() == b.name(); + } + return a.defined() == b.defined(); + }; + return Equal(a, b, layout_equal); + } +}; + +using VarLayoutMap = Map; + +/*! + * \brief Layout conversion interface. + * \param call The call node. + * \param desired_layouts The desired layouts of the operator. + * \param var_layout_map The layout of the variables. + */ +using FRelaxInferLayout = runtime::TypedPackedFunc>& desired_layouts, + const VarLayoutMap& var_layout_map)>; + +/*! + * \brief Initialize a layout given the number of dimensions. + * \param ndim The number of dimensions. + * \return The initialized layout. + */ +Layout InitialLayout(int ndim); + +/*! + * \brief Initialize a layout decision given the number of dimensions. + * \param ndim The number of dimensions. + * \return The initialized layout decision. + */ +LayoutDecision InitialLayoutDecision(int ndim); + +/*! + * \brief Initialize a nested layout decision given the struct info. + * \param sinfo The sinfo. + * \return The initialized nested layout decision. + */ +NLayout InitialNLayout(const StructInfo& sinfo); + +/*! + * \brief Initialize a nested layout decision given expression + * \param sinfo The expr + * \return The initialized nested layout decision. + */ +NLayout InitialNLayout(const Expr& expr); + +/*! + * \brief Transpose the input layout like the src layout to the dst layout. + * \param input The input layout. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed input layout. + */ +Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst); + +/*! + * \brief Transpose the input string like the src layout to the dst layout. + * \param input The input str. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed input str. + */ +String TransposeStrLike(const String& input, const Layout& src, const Layout& dst); + +/*! + * \brief Find axis in the dst layout. 0 represents the first axis, 1 represents the second axis, + * etc. + * \param dst The destination layout. + * \param axis The axis to be found + * \return The axis in the dst layout. + */ +int FindAxis(const Layout& dst, int axis); + +/*! + * \brief Get the layout decision of the expr. The expr must be a Tensor. + * \param var_layout_map The layout of the variables. + * \param arg The expr. + * \return The layout decision of the expr. 
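+ * For example (illustrative), a 4-d tensor argument with no recorded decision falls back to
+ * the initial layout "ABCD".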
+ */ +LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& arg); + +/*! + * \brief Get the nested layout decision of the expr. The expr must be a nested Tensor. + * \param var_layout_map The layout of the variables. + * \param arg The expr. + * \return The nested layout decision of the expr. + */ +NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg); + +/*! + * \brief Check if the op is not in the desired layout + * \param call The call node contains the op + * \param desired_layouts The desired layouts of the operator. + * \return True if the op is not in the desired layout. + */ +bool NoDesiredLayout(const Call& call, const Map>& desired_layouts); + +/*! + * \brief Let a tensor with ndim to follow the src layout decision. + * \param src The source layout decision. + * \param dst_ndim The number of dimensions of the tensor. + * \return The layout decision of the tensor. + */ +LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_ diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc new file mode 100644 index 000000000000..74920823100a --- /dev/null +++ b/src/relax/transform/lambda_lift.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. 
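+ *
+ * Illustrative sketch (Relax pseudo-syntax, not a real test case):
+ *   def main(x):
+ *       w = ...
+ *       inner = fn(z) { z + w }   # free var: w
+ *       return inner(x)
+ * The inner function is lifted to a global "lifted_func_0" with params (z, w); the binding of
+ * `inner` becomes relax.make_closure(lifted_func_0, (w,)) and the call site is rewritten to go
+ * through relax.invoke_closure.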
+ */ +class LambdaLifter : public ExprMutator { + public: + explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; } + + using ExprMutator::VisitExpr_; + + void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.push_back(binding->var); + } + Expr new_value = this->VisitExpr(binding->value); + if (new_value->struct_info_.defined() && + !new_value->struct_info_.same_as(binding->var->struct_info_)) { + binding->var->struct_info_ = GetStructInfo(new_value); + binding->var->checked_type_ = new_value->checked_type_; + } + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } + if (is_lambda) { + recur_vars_.pop_back(); + } + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (const auto* var_node = call_node->op.as()) { + auto var = GetRef(var_node); + bool has_closure = HasClosure(var); + auto val = builder_->LookupBinding(var); + if (const auto* fsinfo_node = GetStructInfo(var).as()) { + auto fsinfo = GetRef(fsinfo_node); + if (!GetStructInfo(call).same_as(fsinfo)) { + call->struct_info_ = fsinfo->ret; + call->checked_type_ = GetStaticType(fsinfo->ret); + } + } + // Call "relax.invoke_closure" to invoke closure + Var clo_arg = var; + if (has_closure && val->IsInstance()) { + if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { + clo_arg = this->var_remap_.at(var->vid); + } + return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + {GetStructInfo(GetRef(call_node))}); + } + auto it = lambda_map_.find(var); + if (it != lambda_map_.end()) { + // flatten nested call, e.g. 
call(y)(x) -> call(x, y)) + Array new_args; + Array params; + for (const auto arg : call->args) { + new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); + } + if (const auto* nest_call = it->second.as()) { + // Update the StructInfo accordingly + for (const auto arg : nest_call->args) { + new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); + } + StructInfo new_func_sinfo; + if (const auto* fsinfo = GetStructInfo(nest_call->op).as()) { + auto func_sinfo = GetRef(fsinfo); + new_func_sinfo = FuncStructInfo(params, func_sinfo->ret); + } + nest_call->op->struct_info_ = new_func_sinfo; + nest_call->op->checked_type_ = GetStaticType(new_func_sinfo); + return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args); + } + return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args); + } + } + return std::move(call); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // TODO(@yongwww): consider appending inner func name into the lifted func name + String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); + auto global = GlobalVar(lift_func_name); + Array free_vars = FreeVars(func); + Array captured_vars; + + Array typed_captured_vars; + bool recursive = false; + for (const auto& var : free_vars) { + if (!recur_vars_.empty() && var == recur_vars_.back()) { + recursive = true; + } else { + captured_vars.push_back(var); + } + } + + Map rebinding_map; + for (auto free_var : captured_vars) { + Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + + // recursive call + if (recursive) { + if (!captured_vars.empty()) { + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // it is required by block_blocker, will be updated later + UpdateStructInfo(global, GetStructInfo(recur_vars_.back())); + lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); + } else { + if (recur_vars_.size() > 0) { + lambda_map_.emplace(recur_vars_.back(), global); + } + } + } + + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : func_node->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(func_node->body); + Expr visited_func; + + if (all_params_unchanged && body.same_as(func_node->body)) { + visited_func = GetRef(func_node); + } else if (const auto& body_sinfo = MatchStructInfo(body)) { + visited_func = Function(params, body, body_sinfo.value(), func_node->attrs); + } else { + visited_func = Function(params, body, func_node->ret_struct_info, func_node->attrs); + } + auto new_func = Downcast(visited_func); + + Function lifted_func; + bool is_closure = IsClosure(captured_vars); + if (!is_closure) { + lifted_func = Function( + /*params=*/new_func->params, + /*body=*/new_func->body, + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/new_func->span); + } else { + // Flatten the Closure + std::vector closure_params; + closure_params.reserve(func->params.size() + typed_captured_vars.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + closure_params.emplace_back(func->params[i]); + } + for (size_t i = 0; i < typed_captured_vars.size(); ++i) { + closure_params.emplace_back(typed_captured_vars[i]); + } + + lifted_func = 
Function(/*params=*/closure_params, + /*body=*/Bind(new_func->body, rebinding_map), + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/func->span); + + for (Var param : closure_params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + } + } + + ICHECK(lifted_func.defined()); + + // Add the lifted function to the module. + global->struct_info_ = GetStructInfo(lifted_func); + global->checked_type_ = lifted_func->checked_type_; + builder_->UpdateFunction(global, lifted_func); + + if (!is_closure) { + return std::move(global); + } else { + // If we need to allocate a closure, + // we pass the variables in its environment here. + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // Call make_closure intrinsic + return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {}); + } + } + + bool HasClosure(const Var& var) { + auto val = builder_->LookupBinding(var); + if (const auto* value = val.as()) { + IRModule ctx_mod = builder_->GetContextIRModule(); + ICHECK(ctx_mod->functions.size() > 0); + BaseFunc func = ctx_mod->Lookup(GetRef(value)); + if (const auto* func_node = func.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } else if (const auto* seq_expr_node = func_node->body.as()) { + // the return var points to a make_closure intrinsic + if (const auto* var = seq_expr_node->body.as()) { + return HasClosure(GetRef(var)); + } + } + } + } else if (const auto* func_node = val.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } else if (const auto* call_node = val.as()) { + // recursive call + auto op = call_node->op; + if (make_closure_op_ == op) { + return true; + } + if (const auto* lv = op.as()) { + return HasClosure(GetRef(lv)); + } + } + return false; + } + + bool IsClosure(const Array& captured_vars) { return captured_vars.size() > 0; } + + IRModule Lift() { + auto glob_funcs = mod_->functions; + for (auto pair : glob_funcs) { + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(pair.first, func); + } + } + return builder_->GetContextIRModule(); + } + + private: + std::unordered_map lambda_map_; + Array recur_vars_; + IRModule mod_; + size_t lift_func_num_ = 0; + /*! \brief Cache ops that would be used later to reduce lookup overhead. */ + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +namespace transform { + +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc new file mode 100644 index 000000000000..350a40c37bf8 --- /dev/null +++ b/src/relax/transform/legalize_ops.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/legalize_ops.cc + * \brief Legalize high-level operator calls in Relax functions to call_tir + * with corresponding low-level TIR PrimFuncs. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose + * values are all known. + * \param sinfo The StructInfo to be checked. + * \return A boolean indicating the given struct info contains shape values that are all known. + */ +bool KnowAllShapeValues(const StructInfo& sinfo) { + if (const auto* tensor_sinfo = sinfo.as()) { + return tensor_sinfo->shape.defined() && + tensor_sinfo->shape.value()->IsInstance(); + } else if (const auto* shape_sinfo = sinfo.as()) { + return shape_sinfo->values.defined(); + } else if (const auto* tuple_sinfo = sinfo.as()) { + return std::all_of(tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), + [](StructInfo field_sinfo) { return KnowAllShapeValues(field_sinfo); }); + } else if (sinfo.as()) { + return true; + } else { + return false; + } +} + +class LegalizeMutator : public ExprMutator { + public: + explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap) + : ExprMutator(mod), mod_(std::move(mod)), cmap_(std::move(cmap)) {} + + IRModule Transform() { + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + auto updated_func = Downcast(this->VisitExpr(func)); + builder_->UpdateFunction(gv, Downcast(updated_func)); + } + } + return builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) final { + Call visited_call = Downcast(this->VisitExprPostOrder_(call)); + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + auto* op_node = visited_call->op.as(); + + // Not an OpNode + if (op_node == nullptr) { + return visited_call; + } + + // Not all shape values are known + if (!std::all_of(visited_call->args.begin(), visited_call->args.end(), + [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) || + !KnowAllShapeValues(GetStructInfo(visited_call))) { + return visited_call; + } + + auto op = GetRef(op_node); + + // Priority: customize > default. + // Check if it has customize legalization registered. + if (cmap_.defined() && cmap_.value().count(op->name)) { + return cmap_.value()[op->name](this->builder_, visited_call); + } + // Check if it has default legalization registered. + if (legalize_map.count(op)) { + return legalize_map[op](this->builder_, visited_call); + } + + // No legalization. + if (op != call_tir_op && op != call_dps_packed_op) { + LOG(WARNING) << "No legalization func for " << op->name << " is found."; + } + return visited_call; + } + + /*! \brief The context IRModule. */ + IRModule mod_; + /*! 
\brief The customized legalization function map. */ + Optional> cmap_; +}; + +namespace transform { + +Pass LegalizeOps(Optional> cmap) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return LegalizeMutator(mod, cmap).Transform(); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"LegalizeOps", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc new file mode 100644 index 000000000000..88939bd1f5ea --- /dev/null +++ b/src/relax/transform/lift_transform_params.cc @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Plan of lifting transform params */ +struct LiftTransformParamsInfoPlan { + Function f_transform_params; // the lifted function that transforms the parameters + std::unordered_map + output_to_index; // the index of the original bindings in the output tuple + std::unordered_set + lifted_bindings; // the bindings of the original function that are lifted +}; + +/*! \brief Builder of the function that transforms the parameters. */ +class TransformParamsFuncBuilder : public ExprMutator { + public: + TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); } + + /*! \brief Add a input parameter. */ + void AddInput(const Var& var) { inputs_.push_back(var); } + + /*! \brief Add a binding to lift. */ + void AddBinding(const VarBinding& binding) { bindings_.push_back(binding); } + + /*! \brief Mark a variable as the output of the function. */ + void MarkOutput(const Var& output) { outputs_.insert(output); } + + /*! 
+ * \brief Build the function that transforms the parameters + * \return The created function, and a map from the variable in the original function to the index + * of the element of the output tuple + */ + std::pair> Build() { + Array input_sinfo; + Array output_vars; + std::unordered_map output_to_index; + + for (const auto& input : inputs_) { + input_sinfo.push_back(Downcast(input->struct_info_.value())); + } + Var params("params", TupleStructInfo(input_sinfo)); + + // Helper to add a variable to the output tuple + // original_var: the binding variable in the original function + // output_var: the variable, which is a binding in the transform_params function, that is added + // to the output tuple + auto f_add_output = [&](const Var& original_var, const Var& output_var) -> void { + output_to_index[original_var] = output_vars.size(); + output_vars.push_back(output_var); + }; + + // Create mapping from the original input variables to the TupleGetItem from the packed + // parameter tuple Add the parameters that are marked as the output of the function to the + // output tuple + for (const auto& input : inputs_) { + input_remap_.emplace(input.get(), TupleGetItem(params, input_remap_.size())); + if (outputs_.count(input)) { + auto output_var = builder_->Emit(input_remap_.at(input.get())); + f_add_output(input, output_var); + } + } + + // Re-emit the bindings that are lifted. Update the output tuple if the binding is marked as the + // output. + for (const auto& binding : bindings_) { + if (outputs_.count(binding->var)) { + auto output_var = builder_->Emit(VisitExpr(binding->value)); + var_remap_[binding->var->vid] = output_var; + f_add_output(binding->var, output_var); + } else { + VisitBinding(binding); + } + } + + // Create the function. + Expr transformed_params = builder_->EmitOutput(Tuple(output_vars)); + BindingBlock block = builder_->EndBlock(); + Expr body = builder_->Normalize(SeqExpr({block}, transformed_params)); + Function f_transform_params = + Function(/*params=*/{params}, /*body=*/body, /*ret_struct_info=*/NullOpt); + return {f_transform_params, output_to_index}; + } + + Expr VisitExpr_(const VarNode* var) final { + if (auto it = input_remap_.find(var); it != input_remap_.end()) { + return builder_->Emit((*it).second); + } else { + return ExprMutator::VisitExpr_(var); + } + } + + // The input parameters of the function. + Array inputs_; + // Remap from the original input variable to TupleGetItem from the packed parameter tuple, which + // is the input of the lifted function. + std::unordered_map input_remap_; + // The bindings that are lifted. + Array bindings_; + // The variables that are marked as the output of the function. + std::unordered_set outputs_; +}; + +/*! + * \brief Visitor that creates the plan of lifting transform params. + * + * Starting from the parameters of the function (they are the initial set of lifted bindings), we + * will visit the body of the function to find the bindings that can be lifted. A binding can be + * lifted if all the variables that it depends on are also lifted. + * + * When a binding cannot be lifted, all the variables that 1) it depends on, and 2) have been + * lifted, will be marked as the boundary variable and will be in the output of the lifted function. 
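+ *
+ * Illustrative sketch (pseudocode, not part of this pass; the surface syntax
+ * is a hypothetical shorthand): with the attribute num_input = 1,
+ *
+ *   fn main(x, w) { let wt = transpose(w); matmul(x, wt) }
+ *
+ * is planned to be split, roughly, into
+ *
+ *   fn main_transform_params(params) { (transpose(params[0]),) }
+ *   fn main(x, params) { matmul(x, params[0]) }
+ *
+ * where transpose/matmul stand for any lifted and non-lifted bindings.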
+ */ +class LiftTransformParamsPlanner : public ExprVisitor { + public: + LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) { + for (int i = num_inputs; i < static_cast(function->params.size()); ++i) { + builder_.AddInput(function->params[i]); + lifted_bindings_.emplace(function->params[i]); + } + VisitExpr(function->body); + + const auto& [f_transform_params, output_to_index] = builder_.Build(); + return {f_transform_params, output_to_index, std::move(lifted_bindings_)}; + } + + private: + void VisitBindingBlock_(const DataflowBlockNode* block) final { + is_in_dataflow_block_ = true; + ExprVisitor::VisitBindingBlock_(block); + is_in_dataflow_block_ = false; + } + + void VisitBinding_(const VarBindingNode* binding) final { + std::vector producers; + bool can_lift = true; + if (!is_in_dataflow_block_) { + can_lift = false; + } + if (const auto* call = binding->value.as()) { + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + can_lift = false; + } + } + + PostOrderVisit(binding->value, [&](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + producers.push_back(var); + if (!lifted_bindings_.count(GetRef(var))) { + can_lift = false; + } + } + }); + if (can_lift) { + lifted_bindings_.insert(binding->var); + builder_.AddBinding(GetRef(binding)); + } else { + for (const VarNode* producer : producers) { + if (lifted_bindings_.count(GetRef(producer))) { + builder_.MarkOutput(GetRef(producer)); + } + } + } + } + + // The bindings that are lifted + std::unordered_set lifted_bindings_; + // The builder of the function that transforms the parameters + TransformParamsFuncBuilder builder_; + // Whether we are in a dataflow block + bool is_in_dataflow_block_{false}; +}; + +/*! + *\brief The rewriter that lifts the transform params of a function and updates the original + * function. + */ +class TransformParamsLifter : public ExprMutator { + public: + explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) {} + + IRModule Lift() { + auto mod = builder_->GetContextIRModule(); + for (const auto& [gv, base_func] : mod->functions) { + // Skip non-Relax functions. + const auto* func_ = base_func.as(); + if (func_ == nullptr) { + continue; + } + // Skip functions that do not have the `num_input` attribute. + Optional opt_num_input = func_->attrs.GetAttr(attr_num_input_); + if (!opt_num_input.defined()) { + continue; + } + Function func = RewriteFunc(GetRef(func_), opt_num_input.value()->value, + gv->name_hint + "_transform_params"); + builder_->UpdateFunction(gv, func); + } + + return builder_->GetContextIRModule(); + } + + private: + Function RewriteFunc(const Function& func, int num_input, String new_func_name) { + LiftTransformParamsPlanner planner; + + // Step 1: Create the plan of lifting transform params + lift_plan_ = planner.Plan(func, num_input); + + // Step 2: Add the lifted function to the module + builder_->AddFunction(lift_plan_.f_transform_params, new_func_name); + + // Step 3: Update the current function. 
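+    // (Roughly: the parameters after the first num_input ones are replaced by
+    //  a single tuple parameter "params", whose struct info is the return
+    //  struct info of the lifted *_transform_params function; uses of lifted
+    //  bindings are remapped to TupleGetItem(params, i) in Step 3.2 below.)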
+ + // Step 3.1: Update the function signature + Var params("params", lift_plan_.f_transform_params->ret_struct_info); + Array new_params; + for (int i = 0; i < num_input; ++i) { + new_params.push_back(func->params[i]); + } + new_params.push_back(params); + + // Step 3.2: Update the function body + for (const auto& [var, index] : lift_plan_.output_to_index) { + param_remap_[var] = TupleGetItem(params, index); + } + auto new_body = VisitExpr(func->body); + + // Step 3.3: Remove function attributes that are not needed + auto new_attrs = func->attrs; + auto* new_attrs_node = new_attrs.CopyOnWrite(); + new_attrs_node->dict.erase(attr_num_input_); + if (new_attrs->dict.empty()) { + new_attrs = NullValue(); + } + + Function new_func(new_params, new_body, func->ret_struct_info, new_attrs); + return new_func; + } + + void VisitBinding_(const VarBindingNode* binding) final { + if (lift_plan_.lifted_bindings.count(binding->var)) { + return; + } + if (const auto* call = binding->value.as()) { + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + var_remap_[binding->var->vid] = Downcast(VisitExpr(call->args[0])); + return; + } + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const VarNode* var) final { + auto it = param_remap_.find(GetRef(var)); + if (it != param_remap_.end()) { + return builder_->Emit(it->second); + } + return ExprMutator::VisitExpr_(var); + } + + Expr VisitExpr_(const DataflowVarNode* var) final { + return VisitExpr_(static_cast(var)); + } + + const char* attr_num_input_ = "num_input"; + // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters. + std::unordered_map param_remap_; + // The plan of lifting the transform params + LiftTransformParamsInfoPlan lift_plan_; +}; + +namespace transform { +Pass LiftTransformParams() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return TransformParamsLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LiftTransformParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc new file mode 100644 index 000000000000..f444d5c4f63f --- /dev/null +++ b/src/relax/transform/merge_composite_functions.cc @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/merge_composite_functions.cc + * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new + * function. 
+ * + * The new function will be annotated with kCodegen and kGlobalSymbol attributes, and it is + * intented to be offloaded to an external backend. + * + * A group for one composite function can be merged into another group for one of its arguments, + * which we call the parent group for brevity, if the following conditions are met: + * - The argument is the result of calling a composite function offloaded to the same backend + * - Merging into the parent group would not create a cyclic dependency with other parent groups + * + * For example, in the subgraph below the bottom group cannot be merged into the left parent group, + * since the right parent group for X depends on an output from the left parent group. + * + * O = Offloaded to A + * X = Offloaded to B + * + * Correct partitioning: + * + * O O + * / \ / \ + * O X --> O + + X + * \ / \ / + * O O + * + * The algorithm proceeds by assigning a group to each subexpression in the function according to + * its dataflow. On encountering a call node whose callee is a composite function, we check the + * two conditions above to see if we can merge this call node into one of its parent groups, and + * if we can merge some of its parent groups. + * + * To detect cyclic dependencies between groups, we propagate dependency relations, both direct + * and indirect ones, as we flow through the function. The propagation of indirect dependencies + * is important since the dependency relation is transitive. + */ + +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using relay::GraphPartitioner; + +namespace { + +using Group = GraphPartitioner::Group; + +/*! \brief Assign group to each subexpression in a function according to its + * dataflow, and returns a mapping from a subexpression to its group. 
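+ *
+ * Groups are relay::GraphPartitioner::Group union-find nodes: MergeGroup()
+ * re-parents one root onto another, and Run() resolves every memoized group
+ * with FindRoot(), so the returned GroupMap always points at root groups.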
*/ +class CompositeGroupsBuilder : public MemoizedExprTranslator { + public: + using GroupMap = std::unordered_map; + using MemoizedExprTranslator::VisitExpr_; + + CompositeGroupsBuilder(IRModule mod, support::Arena* arena) : mod_(mod), arena_(arena) {} + + GroupMap Run(Function func) { + for (const auto& param : func->params) { + memo_[param] = arena_->make(); + } + VisitExpr(func->body); + + GroupMap group_map; + for (const auto& [expr, group] : memo_) { + group_map[expr.get()] = group->FindRoot(); + } + + return group_map; + } + + Group* VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + return VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + + void VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + VisitBinding(binding); + } + } + + void VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + VisitBinding(binding); + } + } + + void VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + } + + Group* VisitExpr_(const SeqExprNode* op) { + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + return VisitExpr(op->body); + } + + Group* VisitExpr_(const CallNode* call) { + std::vector groups_to_merge = GetGroupsToMerge(call); + Group* group; + + if (groups_to_merge.size() == 0) { + // Create new group if there is nothing to merge with + group = CreateNewGroup(call); + } else { + auto it = groups_to_merge.cbegin(); + // Assign the first mergable group to current node + // to reduce the number of groups created + group = *it++; + group->num_nodes += 1; + + // Merge all groups + for (; it != groups_to_merge.cend(); ++it) { + MergeGroup(*it, group); + } + } + + UpdateGroupDependencies(group, call->args); + return group; + } + + private: + String GetCodegenName(const std::string& composite_name) { + auto delim_pos = composite_name.find("."); + ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " + "start with a compiler name followed by period."; + return composite_name.substr(0, delim_pos); + } + + Optional GetCodegenName(const Expr& callee) { + auto const* gvar = callee.as(); + if (!gvar) { + return NullOpt; + } + + auto composite_name_opt = + mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); + if (!composite_name_opt) { + return NullOpt; + } + + return GetCodegenName(composite_name_opt.value()); + } + + Optional GetCodegenName(Group* group) { + return Downcast>(group->attrs.Get(attr::kCodegen)); + } + + Group* CreateNewGroup(const CallNode* call) { + Group* group = arena_->make(); + if (Optional codegen_name = GetCodegenName(call->op)) { + group->attrs.Set(attr::kCodegen, codegen_name.value()); + } + return group; + } + + void MergeGroup(Group* from, Group* to) { + ICHECK_EQ(GetCodegenName(from), GetCodegenName(to)); + + Group* from_root = from->FindRoot(); + Group* to_root = to->FindRoot(); + if (from_root == to_root) { + return; + } + + from_root->parent = to_root; + to_root->num_nodes += from_root->num_nodes; + + // Update the group_deps_, maintaining the invariant that + // all groups in the map are root groups. 
+ group_deps_[to_root].merge(group_deps_[from_root]); + group_deps_.erase(from_root); + for (auto& it : group_deps_) { + if (it.second.count(from_root)) { + it.second.erase(from_root); + it.second.insert(to_root); + } + } + } + + std::unordered_set GetParentGroupDependencies(const Array& args) { + // Collect groups that parent groups depend on + std::unordered_set dependencies; + + for (const auto& arg : args) { + for (auto dep : group_deps_[memo_[arg]->FindRoot()]) { + dependencies.insert(dep); + } + } + + return dependencies; + } + + void UpdateGroupDependencies(Group* group, const Array& args) { + Group* group_root = group->FindRoot(); + + for (const auto& arg : args) { + auto arg_group_root = memo_[arg]->FindRoot(); + if (arg_group_root == group_root) { + // If arg and the current node are in the same group, + // there is nothing to update. + continue; + } + // Add the group of arg as dependency + group_deps_[group_root].insert(arg_group_root); + // Propagate dependencies of arg + for (auto dep : group_deps_[arg_group_root]) { + group_deps_[group_root].insert(dep); + } + } + } + + std::vector GetGroupsToMerge(const CallNode* call) { + Optional codegen_name = GetCodegenName(call->op); + if (!codegen_name.defined()) { + return {}; + } + + std::vector groups_to_merge; + std::unordered_set parent_dependencies = GetParentGroupDependencies(call->args); + + for (const auto& arg : call->args) { + auto arg_group = memo_[arg]; + Optional arg_codegen_name = GetCodegenName(arg_group); + if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) { + // If there is a parent group with the same target, which none of the parent dependency + // groups depends on, merging "this" call node into the parent group will not form a cyclic + // dependency. + groups_to_merge.push_back(arg_group); + } + } + + return groups_to_merge; + } + + IRModule mod_; + support::Arena* arena_; + // Map from group to its dependencies. All groups in this map, whether it's + // the key or in value, should be root node (that is, group->parent == nullptr). + std::unordered_map> group_deps_; +}; + +/*! \brief Inline definitions of composite functions at the global level into their call sites. + This is necessary to make functions created by MergeCompositeFunctions self-contained - each + external backend compiler does not need to refer to the original containing module. 
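+
+   Illustrative sketch (pseudocode, names hypothetical): inside a merged
+   function, a call to a global composite function such as gv_composite(a, b)
+   is rewritten into a call to a fresh local copy of its body, i.e.
+   (local_copy_of_composite)(a, b); the copy is made once per callee via
+   CopyWithNewVars and reused for repeated calls.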
+ */ +class CompositeInliner : public ExprMutator { + public: + explicit CompositeInliner(IRModule mod) : ExprMutator(mod), mod_(mod) {} + using ExprMutator::VisitExpr_; + + Function Run(Function func) { + inlined_functions_ = Map(); + auto new_body = VisitExpr(func->body); + auto new_func = + Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + return new_func; + } + + Expr VisitExpr_(const CallNode* call) { + if (call->op->IsInstance()) { + auto gvar = Downcast(call->op); + auto func = Downcast(mod_->Lookup(gvar)); + + if (func->GetAttr(attr::kComposite)) { + if (!inlined_functions_.count(func)) { + inlined_functions_.Set(func, CopyWithNewVars(func)); + } + return Call(inlined_functions_[func], call->args); + } + } + + return ExprMutator::VisitExpr_(call); + } + + private: + IRModule mod_; + Map inlined_functions_; +}; + +} // namespace + +IRModule MergeCompositeFunctions(IRModule mod) { + auto gvar = mod->GetGlobalVar("main"); + auto func = Downcast(mod->Lookup(gvar)); + support::Arena arena; + auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func); + auto new_mod = MakeGroupedFunctions(mod, group_map); + + CompositeInliner inliner(mod); + std::vector> to_update; + for (const auto& [gvar, func] : new_mod->functions) { + if (func->GetAttr(attr::kCodegen)) { + auto new_func = inliner.Run(Downcast(func)); + new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint); + to_update.emplace_back(gvar, new_func); + } + } + for (const auto& [gvar, func] : to_update) { + new_mod->Update(gvar, func); + } + // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this. + return DeadCodeElimination(new_mod, {"main"}); +} + +namespace transform { + +Pass MergeCompositeFunctions() { + runtime::TypedPackedFunc pass_func = // + [=](IRModule mod, PassContext pc) { return relax::MergeCompositeFunctions(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") + .set_body_typed(MergeCompositeFunctions); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc new file mode 100644 index 000000000000..e205e984df02 --- /dev/null +++ b/src/relax/transform/meta_schedule.cc @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/transform/meta_schedule.cc + * \brief Pass for meta_schedule tuning + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace transform { + +class MetaScheduleTuner { + public: + explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, + Map params = {}) + : target_(target), + work_dir_(work_dir), + max_trials_global_(max_trials_global), + params_(params) { + candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); + ICHECK(candgen_func_) << "Default candidate generation function is not found."; + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { + Trace trace = Downcast(ctx->GetCurrentTrace()); + ctx->PopTrace(); + Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + Trace best_trace = candidates[0]; + ctx->PushTrace(best_trace); + // since we separate tuning from application, return original IRModule + return mod; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { + // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace + // stack. Revisit later when we collect more usecases. + Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); + + Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + // since we separate tuning from application, return original IRModule + return f; + } + + private: + Target target_; + String work_dir_; + Integer max_trials_global_; + Map params_; + const runtime::PackedFunc* candgen_func_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +Pass MetaScheduleApplyDatabase(Optional work_dir) { + using tvm::meta_schedule::Database; + Target target = Target::Current(false); + const runtime::PackedFunc* normalize_mod_func_ = + runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + Database database{nullptr}; + if (Database::Current().defined()) { + database = Database::Current().value(); + } else { + ICHECK(work_dir.defined()); + String path_workload = work_dir.value() + "/database_workload.json"; + String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + LOG(WARNING) << "Creating JSONDatabase. 
Workload at: " << path_workload + << ", Tuning records at: " << path_tuning_record; + database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); + } + + Map result; + for (const auto& iter : mod->functions) { + GlobalVar gv = iter.first; + BaseFunc base_func = iter.second; + if (const auto* prim_func_node = base_func.as()) { + tir::PrimFunc prim_func = GetRef(prim_func_node); + + IRModule tir_mod = (*normalize_mod_func_)(prim_func); + if (Optional sch = database->QuerySchedule(tir_mod, target, gv->name_hint)) { + IRModule new_mod = sch.value()->mod(); + ICHECK_EQ(new_mod->functions.size(), 1); + BaseFunc new_base_func = (*new_mod->functions.begin()).second; + ICHECK(new_base_func->IsInstance()); + tir::PrimFunc new_prim_func = Downcast(new_base_func); + // copy the original attrs + new_prim_func = WithAttrs(std::move(new_prim_func), {prim_func->attrs->dict}); + new_prim_func = WithAttr(std::move(new_prim_func), tir::attr::kIsScheduled, Bool(true)); + result.Set(gv, new_prim_func); + continue; + } else { + LOG(WARNING) << "Tuning record is not found for primfunc: " << gv->name_hint; + } + } + result.Set(gv, base_func); + } + return IRModule(result, // functions + {}, // type_definitions + {}, // import_set + {}, // map + mod->attrs); // attrs); + }; + return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); +} + +Pass MetaScheduleTuneIRMod(Map params, String work_dir, + Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx); + }; + return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneIRModule", + /*required*/ {}, + /*traceable*/ true); +} + +Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = + [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx); + }; + return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneTIR", + /*required*/ {}, + /*traceable*/ true); +} + +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") + .set_body_typed(MetaScheduleApplyDatabase); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc new file mode 100644 index 000000000000..915498178f0f --- /dev/null +++ b/src/relax/transform/normalize.cc @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/normalize.cc + * \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// TODO(@altanh): LCA binding lifting +class NormalizeMutator : public ExprMutatorBase { + public: + NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); } + + Expr VisitExpr(const Expr& expr) override { + return builder_->Normalize(ExprMutatorBase::VisitExpr(expr)); + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr body = this->VisitWithNewScope(op->body, op->params); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_struct_info, op->attrs); + } + } + + Expr VisitExpr_(const IfNode* op) final { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } + } + + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + builder_->BeginBindingBlock(); + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + builder_->EndScope(); + return ret; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; + } + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + void VisitBinding(const Binding& 
binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + + void VisitBinding_(const VarBindingNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } + } + + void VisitBinding_(const MatchCastNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized( + MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); + } + } + + private: + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; +}; // namespace relax + +Expr Normalize(const Expr& e) { return NormalizeMutator().VisitExpr(e); } + +namespace transform { + +Pass Normalize() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(Normalize(f)); }; + return CreateFunctionPass(pass_func, 1, "Normalize", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc new file mode 100644 index 000000000000..e5d654fba355 --- /dev/null +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/rewrite_dataflow_reshape.cc + * \brief Transform all reshape within dataflow block to a relax.reshape operator + */ +#include +#include +#include + +#include "../op/tensor/manipulate.h" + +namespace tvm { +namespace relax { + +class DataflowReshapeRewriter : public ExprMutator { + public: + explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {} + + private: + using ExprMutator::VisitExpr_; + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + // We only rewrite the bindings inside dataflow blocks. 
+ if (const auto* dataflow_block = block.as()) { + return VisitBindingBlock_(dataflow_block); + } else { + return block; + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + // We only rewrite the bindings that are not dataflow output (which means they are not + // externally referenced) + if (!binding->var->IsInstance()) { + this->builder_->EmitNormalized(GetRef(binding)); + } else { + ExprMutator::VisitBinding_(binding); + } + } + + Expr VisitExpr_(const CallNode* call) final { + if (!IsCallingTIRReshape(call)) { + return GetRef(call); + } + + // We bring the calls of reshape PrimFunc back to calls of high-level + // relax.reshape op, which will be lowered to calls of the ExternFunc + // vm.builtin.reshape in the VMBuiltinLower pass. + Array args = Downcast(call->args[1])->fields; + ICHECK_EQ(args.size(), 1); + TensorStructInfo res_sinfo = Downcast(call->struct_info_); + ICHECK(res_sinfo->shape.defined()); + return reshape(args[0], res_sinfo->shape.value()); + } + + bool IsCallingTIRReshape(const CallNode* call) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op != call_tir_op) { + return false; + } + const GlobalVar& global_var = Downcast(call->args[0]); + const auto* func = mod_->functions.Get(global_var).as(); + ICHECK_NOTNULL(func); + return HasReshapePattern(GetRef(func)); + } + + const IRModule& mod_; +}; + +Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) { + return DataflowReshapeRewriter(mod)(f); +} + +namespace transform { + +Pass RewriteDataflowReshape() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(RewriteDataflowReshape(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") + .set_body_typed(RewriteDataflowReshape); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc new file mode 100644 index 000000000000..25867419a99e --- /dev/null +++ b/src/relax/transform/run_codegen.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/run_codegen.cc + * \brief Run codegen for annotated relax functions. 
+ */ + +#include +#include + +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +class CodeGenRunner : ExprMutator { + public: + using OptionMap = Map; + + explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {} + + IRModule Run(Optional> target_options, Array entry_functions) { + IRModule mod = builder_->GetContextIRModule(); + for (const String& entry_func_name : entry_functions) { + auto entry_func = mod->Lookup(entry_func_name); + auto gvar = mod->GetGlobalVar(entry_func_name); + builder_->UpdateFunction(gvar, Downcast(VisitExpr(entry_func))); + } + + auto ext_mods = InvokeCodegen(mod, target_options.value_or({})); + auto out_mod = builder_->GetContextIRModule(); + + if (ext_mods.size()) { + out_mod = WithAttr(out_mod, tvm::attr::kExternalMods, std::move(ext_mods)); + } + + if (constant_names.size()) { + // Some backends (e.g. TensorRT) expect constants to be passed when they are instantiated + Map constants; + for (const auto& [constant, name] : constant_names) { + ICHECK(!constants.count(name)) << "More than one constant with the name " << name; + constants.Set(name, constant->data); + } + out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant, std::move(constants)); + } + + // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this. + return DeadCodeElimination(out_mod, entry_functions); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* gvar_node = call_node->op.as()) { + const GlobalVar gvar = GetRef(gvar_node); + + auto create_call_dps_packed = [call_node, this](Expr extern_func, + StructInfo ret_struct_info) { + Array new_args({extern_func}); + new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); + + static const Op& call_op = Op::Get("relax.call_dps_packed"); + + return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info}); + }; + + if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) { + return create_call_dps_packed(it->second.first, it->second.second); + } else { + // TODO(@sunggg): Is there any better way to get this func? + Function func = Downcast(builder_->GetContextIRModule()->Lookup(gvar)); + Expr new_func = VisitExpr(func); + + if (new_func->IsInstance()) { + auto ret_sinfo = GetStructInfo(call); + extern_funcs_[gvar_node] = {new_func, ret_sinfo}; + // Remove the global symbol and codegen attributes from the function so that it can be + // removed the module. 
+ static const runtime::PackedFunc* RemoveFuncAttrFunc = + runtime::Registry::Get("ir.BaseFuncWithoutAttr"); + ICHECK(RemoveFuncAttrFunc); + func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol); + func = (*RemoveFuncAttrFunc)(func, attr::kCodegen); + builder_->UpdateFunction(gvar, func); + return create_call_dps_packed(new_func, ret_sinfo); + } + } + } + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + + return Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + } + + Expr VisitExpr_(const FunctionNode* func_node) override { + Function func = GetRef(func_node); + auto opt_codegen = func->GetAttr(attr::kCodegen); + if (opt_codegen) { + auto ext_symbol = GetExtSymbol(func); + size_t count = 0; + PostOrderVisit(func->body, [=, &count](Expr e) { + if (e->IsInstance()) { + // Make sure to pick a unique name + auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + std::to_string(count++); + auto constant = Downcast(e); + constant_names.Set(constant, name); + } + }); + return ExternFunc(GetExtSymbol(func)); + } else { + return ExprMutator::VisitExpr_(func_node); + } + } + + private: + Array InvokeCodegen(IRModule mod, Map target_options) { + std::unordered_map> target_functions; + + for (const auto& entry : mod->functions) { + if (entry.second->IsInstance()) { + continue; + } + PostOrderVisit(entry.second, [&target_functions](Expr e) { + if (e->IsInstance()) { + auto f = Downcast(e); + if (auto target_opt = f->GetAttr(attr::kCodegen)) { + String target = target_opt.value(); + target_functions[target].push_back(f); + } + } + }); + } + + Array ext_mods; + + for (const auto& [target, functions] : target_functions) { + OptionMap options = target_options.Get(target).value_or({}); + // Start the codegen process. + // Get the codegen with its ffi key. + String codegen_name = "relax.ext." + target; + auto codegen = runtime::Registry::Get(codegen_name); + ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n"; + + Array compiled_functions = (*codegen)(functions, options, constant_names); + ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); + } + + return ext_mods; + } + + /*! \brief The names of all constants in the original module. */ + Map constant_names; + /*! \brief Extern funcs and their return struct infos for each global variable. */ + std::unordered_map> extern_funcs_; +}; + +} // namespace relax + +namespace transform { +Pass RunCodegen(Optional>> target_options, + Array entry_functions) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + return relax::CodeGenRunner(m).Run(target_options, entry_functions); + }; + return CreateModulePass(pass_func, 0, "RunCodegen", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc new file mode 100644 index 000000000000..7fcc2cb34a76 --- /dev/null +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_non_dataflow.cc + * \brief Transform all dataflow structure to non-dataflow version. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../tir/schedule/ir_comparator.h" + +namespace tvm { + +static const constexpr char* kLibraryKernel = "library_kernel"; +static const constexpr char* kCSource = "c_source"; +static const constexpr char* kCSourceFmt = "c_source_fmt"; +static const constexpr char* kCSourceFmtCuda = "cu"; + +namespace tir { + +using relax::FCodegen; +using relax::MatchResult; +using relax::TIRPattern; + +/*! \brief helper to match a for stmt to a pattern*/ +class ForMatcher : public TensorizeComparator { + public: + using SymbolMap = std::unordered_map; + explicit ForMatcher(const tir::PrimFunc& pattern, const Array& pattern_vars) + : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { + for (const auto& pattern_var : pattern_vars) { + this->pattern_vars_.insert(pattern_var); + } + this->evaluated_symbols.push_back(SymbolMap()); + } + + bool Match(const For& top) { + const ForNode* pattern_top = pattern_->body.as()->block->body.as(); + ICHECK(pattern_top) << "Invalid pattern function"; + if (!VisitStmt(top, GetRef(pattern_top))) { + return false; + } + // Get evaluated symbols, buffers from the pattern. 
+ for (const auto& arg : pattern_->params) { + auto it = pattern_->buffer_map.find(arg); + if (it != pattern_->buffer_map.end()) { + auto itt = rhs_buffer_map_.find((*it).second); + ICHECK(itt != rhs_buffer_map_.end()); + evaluated_buffers.push_back(itt->second); + } + } + return true; + } + + std::vector evaluated_symbols; + std::vector evaluated_buffers; + + private: + using ExprComparator::VisitExpr_; + + Optional QueryEvaluatedSymbols(const Var& var) { + for (const SymbolMap& symbol_map : evaluated_symbols) { + auto it = symbol_map.find(var); + if (it != symbol_map.end()) { + return it->second; + } + } + return NullOpt; + } + + bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { + if (const auto* op = rhs.as()) { + if (pattern_vars_.count(GetRef(op))) { + // special case for pattern vars + const auto* lhs_ptr = lhs.as(); + if (lhs_ptr == nullptr) { + if (lhs->IsInstance() || lhs->IsInstance()) { + Optional value = QueryEvaluatedSymbols(GetRef(op)); + if (value.defined()) { + if (!analyzer_.CanProveEqual(lhs, value.value())) return false; + } else { + evaluated_symbols.back()[GetRef(op)] = lhs; + } + return true; + } else { + return false; + } + } + } + } + // pattern_var * expr + if (const auto* rhs_ptr = rhs.as()) { + const auto* operand_a = rhs_ptr->a.as(); + const auto* operand_b = rhs_ptr->b.as(); + if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + // pattern var is on the left + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 1); + return true; + } + } + if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + // pattern var is on the right + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 1); + return true; + } + } + } + // pattern_Var + expr + if (const auto* rhs_ptr = rhs.as()) { + const auto* operand_a = rhs_ptr->a.as(); + const auto* operand_b = rhs_ptr->b.as(); + if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + // pattern var is on the left + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 0); + return true; + } + } + if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + // pattern var is on the right + evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(lhs, rhs_ptr->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + evaluated_symbols.pop_back(); + if (match) { + evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 0); + return true; + } + } + } + return TensorizeComparator::VisitExpr(lhs, rhs); + } + + bool VisitExpr_(const tir::AddNode* add, const PrimExpr& 
other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(add->a, rhs->a) && VisitExpr(add->b, rhs->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(add->a, rhs->b) && VisitExpr(add->b, rhs->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + return false; + } + + bool VisitExpr_(const tir::MulNode* mul, const PrimExpr& other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(mul->a, rhs->a) && VisitExpr(mul->b, rhs->b); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + { + this->evaluated_symbols.push_back(SymbolMap()); + bool match = VisitExpr(mul->a, rhs->b) && VisitExpr(mul->b, rhs->a); + SymbolMap symbol_map = std::move(evaluated_symbols.back()); + this->evaluated_symbols.pop_back(); + if (match) { + this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); + return true; + } + } + return false; + } + + bool VisitExpr_(const tir::CallNode* call, const PrimExpr& other) final { + const auto* rhs = other.as(); + if (rhs == nullptr) return false; + const auto* lhs_op = call->op.as(); + const auto* rhs_op = rhs->op.as(); + if (lhs_op == nullptr || rhs_op == nullptr) return false; + if (lhs_op->name != rhs_op->name) return false; + if (call->args.size() != rhs->args.size()) return false; + for (size_t i = 0; i < call->args.size(); ++i) { + if (!VisitExpr(call->args[i], rhs->args[i])) return false; + } + return true; + } + + bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + loop_stack_lhs_.push_back(GetRef(op)); + loop_stack_rhs_.push_back(GetRef(rhs)); + // The body of loop must be loop or BlockRealize + if (!op->body->IsInstance() && !op->body->IsInstance()) { + return false; + } + if (!rhs->body->IsInstance() && !rhs->body->IsInstance()) { + return false; + } + // Build mapping between the loop vars + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + // Only handle the case where the loop start from 0 + if (!is_zero(op->min) || !is_zero(rhs->min)) return false; + if (op->thread_binding.defined() || rhs->thread_binding.defined()) return false; + if (op->kind != ForKind::kSerial || op->kind != rhs->kind) return false; + if (!op->annotations.empty() || !rhs->annotations.empty()) return false; + // Match the extents of loops + if (!VisitExpr(op->extent, rhs->extent)) return false; + return VisitStmt(op->body, rhs->body); + } + + bool VisitStmt_(const tir::BlockNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order should match. + // When checking iter vars, DefEqual is used to remap variables. 
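The Add/Mul visitors above handle commutativity by trying both operand orders, each attempt under a scratch symbol map that is committed into the enclosing map only when that attempt succeeds. A standalone sketch of that speculate-then-commit pattern follows; ScopedBindings, Bind and MatchCommutative are illustrative names, not TVM API.

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Illustrative sketch only: a stack of scratch scopes; a failed speculative
// match is popped and discarded, a successful one is merged into the
// enclosing scope.
struct ScopedBindings {
  std::vector<std::map<std::string, long>> scopes{{}};

  void Push() { scopes.emplace_back(); }
  void PopAndDiscard() { scopes.pop_back(); }
  void PopAndCommit() {
    auto top = std::move(scopes.back());
    scopes.pop_back();
    scopes.back().insert(top.begin(), top.end());
  }
};

// Bind symbol `sym` to value `v`, consistent with every visible scope.
bool Bind(ScopedBindings* b, const std::string& sym, long v) {
  for (auto it = b->scopes.rbegin(); it != b->scopes.rend(); ++it) {
    auto found = it->find(sym);
    if (found != it->end()) return found->second == v;
  }
  b->scopes.back()[sym] = v;
  return true;
}

// Match a commutative pair (x, y) against pattern symbols (p, q): try both
// operand orders, keeping only the bindings of the order that succeeds.
bool MatchCommutative(ScopedBindings* b, const std::string& p, const std::string& q,
                      long x, long y) {
  b->Push();
  if (Bind(b, p, x) && Bind(b, q, y)) { b->PopAndCommit(); return true; }
  b->PopAndDiscard();
  b->Push();
  if (Bind(b, p, y) && Bind(b, q, x)) { b->PopAndCommit(); return true; }
  b->PopAndDiscard();
  return false;
}

int main() {
  ScopedBindings b;
  assert(Bind(&b, "m", 7));                      // m is already known to be 7
  assert(MatchCommutative(&b, "n", "m", 7, 9));  // only the swapped order fits
  assert(b.scopes.back().at("n") == 9);
  return 0;
}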
+ if (!CompareArray(op->iter_vars, rhs->iter_vars, &ForMatcher::CompareIterVar)) { + return false; + } + // disallow alloc buffers inside the block + if (!op->alloc_buffers.empty() || !rhs->alloc_buffers.empty()) return false; + if (!CompareArray(op->writes, rhs->writes, &ForMatcher::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &ForMatcher::CompareBufferRegion)) { + return false; + } + // The body of the block has to be BufferStore + if (!op->body->IsInstance() || !rhs->body->IsInstance()) { + return false; + } + // Handle init block + if (op->init.defined() && !rhs->init.defined()) return false; + if (!op->init.defined() && rhs->init.defined()) return false; + if (op->init.defined() && rhs->init.defined()) { + if (!VisitStmt(op->init.value(), rhs->init.value())) return false; + } + return VisitStmt(op->body, rhs->body); + } + + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) final { + const auto* rhs = other.as(); + // Only allow trivial bindings + for (size_t i = 0; i < op->iter_values.size(); ++i) { + if (!op->iter_values[i].same_as(loop_stack_lhs_[i]->loop_var)) return false; + } + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!rhs->iter_values[i].same_as(loop_stack_rhs_[i]->loop_var)) return false; + } + // Disallow predicates now + if (!is_one(op->predicate) || !is_one(rhs->predicate)) return false; + return VisitStmt(op->block, rhs->block); + } + + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); + } + + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); + } + + bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Compare shape + if (lhs->shape.size() != rhs->shape.size()) return false; + for (size_t i = 0; i < lhs->shape.size(); ++i) { + if (!VisitExpr(lhs->shape[i], rhs->shape[i])) return false; + } + // Remap both buffer itself and buffer data + equal = + DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } + } + return equal; + } + + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) { + return false; + } + return CompareArray(lhs->region, rhs->region, &ForMatcher::CompareRange); + } + + template + bool CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + return CompareArray(lhs->indices, rhs->indices, &ForMatcher::VisitExpr); + } + + template + bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(static_cast(this)->*cmp)(lhs[i], rhs[i])) return false; + } + return true; + } + + arith::Analyzer analyzer_; + std::vector loop_stack_lhs_, loop_stack_rhs_; + tir::PrimFunc pattern_; + std::unordered_set pattern_vars_; +}; + +/*! 
\brief Analyze the function and match it with a list of patterns */ +class TIRPatternMatcher { + public: + static Array Match(Array patterns, Stmt body) { + TIRPatternMatcher matcher(patterns); + matcher.OpMatternMatch(body); + if (matcher.fail_) return {}; + return matcher.match_results_; + } + + private: + explicit TIRPatternMatcher(Array patterns) : patterns_(patterns) {} + + // Find an op that matches this block + bool BlockPatternMatch(const For& top) { + for (const TIRPattern& pattern : patterns_) { + tir::PrimFunc pattern_func = pattern; + Array pattern_symbolic_vars; + int buffer_count = pattern_func->buffer_map.size(); + for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { + pattern_symbolic_vars.push_back(pattern_func->params[i]); + } + ForMatcher block_matcher(pattern_func, pattern_symbolic_vars); + if (block_matcher.Match(top)) { + // We have found a match + Array symbol_values; + for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { + symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]); + } + match_results_.push_back( + MatchResult(pattern, symbol_values, block_matcher.evaluated_buffers)); + return true; + } + } + // The block fails to match any pattern + return false; + } + + // For each block in the body, try to find its corresponding pattern one by one + void OpMatternMatch(const Stmt& body) { + Array blocks; + if (body->IsInstance()) { + // {for} + blocks = {body}; + } else if (const SeqStmtNode* seq = body.as()) { + blocks = seq->seq; + } else { + fail_ = true; + return; + } + for (const Stmt& stmt : blocks) { + const ForNode* loop = stmt.as(); + if (loop == nullptr || !BlockPatternMatch(GetRef(loop))) { + break; + } + } + if (match_results_.empty()) { + fail_ = true; + } + } + /*! \brief Indicate whether we fail to match.*/ + bool fail_ = false; + /*! \brief The patterns we match the target stmt to.*/ + Array patterns_; + /*! \brief The results of the matching process.*/ + Array match_results_; +}; + +/*! \brief helper class to partition a function into 2 parts. Return function information which we + * can use to construct the two partitioned parts.*/ +class FunctionPartitioner : public StmtExprVisitor { + public: + explicit FunctionPartitioner(int num_matched_ops) : num_matched_ops_(num_matched_ops) {} + /*! \brief alloc_buffers for the first function */ + std::unordered_set allocs1; + /*! \brief alloc_buffers for the second function */ + std::unordered_set allocs2; + /*! \brief whether the current block is in the first function */ + Map block_partition; + /*! \brief input buffers for the first function */ + std::unordered_set input1; + /*! \brief input buffers for the second function */ + std::unordered_set input2; + /*! \brief The output buffer for the first function, which is also the input buffer for the second + function */ + Buffer intermediate_buffer; + /*! \brief Indicate whether we have failed. If failed, we will not do any further analysis and + directly return the original one. 
*/ + bool fail = false; + + private: + void VisitStmt_(const BlockNode* op) final { + block_counter_++; + bool is_matching_ = block_counter_ <= num_matched_ops_; + if (block_counter_ == num_matched_ops_) { + allocs1.erase(intermediate_buffer); + } + for (const auto& read : op->reads) { + if (is_matching_) { + input1.insert(read->buffer); + } else { + input2.insert(read->buffer); + } + } + for (const auto& write : op->writes) { + if (is_matching_) { + allocs1.insert(write->buffer); + } else if (allocs1.count(write->buffer)) { + fail = true; + return; + } else { + allocs2.insert(write->buffer); + } + if (is_matching_) { + intermediate_buffer = write->buffer; + } else { + input2.insert(write->buffer); + } + } + block_partition.Set(GetRef(op), Bool(is_matching_)); + } + // The number of matched ops in the function + size_t num_matched_ops_; + size_t block_counter_ = 0; +}; + +/*! \brief remove parts according to block partition, and update the alloc_buffers for blocks */ +class BlockRemover : public StmtExprMutator { + public: + static Stmt RemoveBlockByPartition( + Stmt stmt, const Map& block_partition, + const std::unordered_set& allocs, + bool is_library_part) { + BlockRemover remover(block_partition, allocs, is_library_part); + return remover(stmt); + } + + private: + BlockRemover(const Map& block_partition, + const std::unordered_set& allocs, + bool is_library_part) + : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {} + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + ObjectPtr n = make_object(*block.operator->()); + if (op->name_hint != "root") { + ICHECK(block_partition.count(GetRef(op))); + bool block_is_library = block_partition[GetRef(op)]->value; + if (!(is_library_part_ ^ block_is_library)) { + n->body = block->body; + } else { + erased_ = true; + } + } + Array alloc_buffers; + for (const Buffer& b : block->alloc_buffers) { + if (allocs_.count(b)) { + alloc_buffers.push_back(b); + } + } + n->alloc_buffers = alloc_buffers; + return Block(n); + } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + Array seq; + for (const Stmt& s : op->seq) { + Stmt new_s = VisitStmt(s); + if (erased_) { + erased_ = false; + } else { + seq.push_back(new_s); + } + } + return SeqStmt::Flatten(seq); + } + + bool erased_ = false; + Map block_partition; + std::unordered_set allocs_; + bool is_library_part_ = false; +}; + +/*! + * \brief Split the input function into two functions, one for the library kernel and one for the + * rest. + * \param func The input function. + * \param arg_partition The input arg for the functions after split. + * \param patterns The patterns to match. + * \param f_codegen The function to generate the code for the library kernel. + * \return A pair of functions, the first one is the library kernel and the second one is the + * rest. + */ +std::pair> SplitFunctions(PrimFunc func, + std::vector>* arg_partition, + Array patterns, + FCodegen f_codegen) { + // Step 1. Find the library kernel and the rest. 
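The partitioner above assigns blocks to the two halves purely by position: the first num_matched_ops blocks (the ones the patterns claimed) form the library half, the remainder forms the epilogue, and the last buffer written by the library half is the intermediate handed to the epilogue. A standalone sketch of that positional split, with illustrative names rather than TVM types:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Illustrative sketch only: the first `num_matched` blocks go to the library
// part, the rest to the epilogue; the last block of the library part defines
// the hand-off buffer the epilogue reads.
struct Partition {
  std::vector<std::string> library_blocks;
  std::vector<std::string> epilogue_blocks;
  std::string intermediate;
};

Partition SplitByCount(const std::vector<std::string>& blocks, std::size_t num_matched) {
  Partition p;
  for (std::size_t i = 0; i < blocks.size(); ++i) {
    if (i < num_matched) {
      p.library_blocks.push_back(blocks[i]);
      p.intermediate = blocks[i];  // last matched block produces the intermediate
    } else {
      p.epilogue_blocks.push_back(blocks[i]);
    }
  }
  return p;
}

int main() {
  // E.g. a GEMM followed by a bias-add and a ReLU, where the first two blocks
  // are matched by a library pattern and the ReLU stays as a TIR epilogue.
  Partition p = SplitByCount({"gemm", "bias_add", "relu"}, 2);
  assert(p.library_blocks.size() == 2 && p.epilogue_blocks.size() == 1);
  assert(p.intermediate == "bias_add");
  return 0;
}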
+ Stmt body = func->body.as()->block->body; + Array match_results = + TIRPatternMatcher::Match(patterns, func->body.as()->block->body); + if (match_results.empty()) { + return {func, NullOpt}; + } + Array codegen_result = f_codegen(match_results); + ICHECK(codegen_result.size() == 3); + String library_code = Downcast(codegen_result[0]); + int num_matched_ops = Downcast(codegen_result[1])->value; + Array func1_args = Downcast>(codegen_result[2]); + if (num_matched_ops == 0) { + return {func, NullOpt}; + } + FunctionPartitioner partitioner(num_matched_ops); + partitioner(body); + if (partitioner.fail) { + return {func, NullOpt}; + } + bool has_second_func = false; + for (const auto& pr : partitioner.block_partition) { + if (!pr.second->value) { + has_second_func = true; + break; + } + } + if (!has_second_func) { + // No need to split the function. + return {WithAttr(func, kLibraryKernel, library_code), NullOpt}; + } + // Step 2. Split the function into two functions. + Stmt body1 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, + partitioner.allocs1, true); + Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, + partitioner.allocs2, false); + // Step 3. Craft the first function. + Array new_params1; + std::vector arg_partition1; + ICHECK_LE(func1_args.size(), partitioner.input1.size()); + for (const auto& buffer : func1_args) { + ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end()); + for (size_t i = 0; i < func->params.size(); i++) { + if (func->buffer_map[func->params[i]].same_as(buffer)) { + new_params1.push_back(func->params[i]); + arg_partition1.push_back(i); + break; + } + } + } + arg_partition->push_back(arg_partition1); + new_params1.push_back(Var("output", DataType::Handle())); + Map new_buffer_map1; + for (const auto& kv : func->buffer_map) { + if (partitioner.input1.count(kv.second)) { + new_buffer_map1.Set(kv.first, kv.second); + } + } + new_buffer_map1.Set(new_params1.back(), partitioner.intermediate_buffer); + PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs); + func1 = WithAttr(func1, kLibraryKernel, library_code); + // Step 4. Craft the second function. 
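arg_partition records, for each half, which positions of the original call_tir argument tuple that half still needs; the caller later rebuilds the two argument lists from those indices and threads the intermediate tensor in between (see SplitMutator further below). A standalone sketch of that index remapping, with BuildCallArgs as an illustrative name:

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Illustrative sketch only: rebuild the argument lists of the two split calls
// from the recorded index partition, feeding the first call's output into the
// second call as its leading argument.
std::pair<std::vector<std::string>, std::vector<std::string>> BuildCallArgs(
    const std::vector<std::string>& original_args,
    const std::vector<std::vector<int>>& arg_partition,
    const std::string& intermediate) {
  std::vector<std::string> args1, args2;
  for (int i : arg_partition[0]) args1.push_back(original_args[i]);
  args2.push_back(intermediate);  // output of the library call feeds the epilogue
  for (int i : arg_partition[1]) args2.push_back(original_args[i]);
  return {args1, args2};
}

int main() {
  auto [args1, args2] =
      BuildCallArgs({"A", "B", "bias"}, {{0, 1}, {2}}, "gemm_out");
  assert((args1 == std::vector<std::string>{"A", "B"}));
  assert((args2 == std::vector<std::string>{"gemm_out", "bias"}));
  return 0;
}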
+ Array new_params2; + std::vector arg_partition2; + new_params2.push_back(Var("input", DataType::Handle())); + for (int i = 0; i < static_cast(func->params.size()); i++) { + Var param = func->params[i]; + if (partitioner.input2.count(func->buffer_map[param])) { + new_params2.push_back(param); + if (i != static_cast(func->params.size()) - 1) { + arg_partition2.push_back(i); + } + } + } + arg_partition->push_back(arg_partition2); + Map new_buffer_map2; + new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer); + for (const auto& kv : func->buffer_map) { + if (partitioner.input2.count(kv.second)) { + new_buffer_map2.Set(kv.first, kv.second); + } + } + PrimFunc func2 = PrimFunc(new_params2, body2, func->ret_type, new_buffer_map2, func->attrs); + return {func1, func2}; +} +} // namespace tir + +namespace relax { +void StringReplace(std::string* subject, const std::string& search, const std::string& replace) { + for (size_t pos = 0; (pos = subject->find(search, pos)) != std::string::npos; + pos += replace.length()) { + subject->replace(pos, search.length(), replace); + } +} + +tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { + using namespace tvm::tir; + Optional library_code = pf->attrs.GetAttr(kLibraryKernel); + if (!library_code.defined()) { + return GetRef(pf); + } + std::string source = library_code.value(); + StringReplace(&source, "{global_symbol}", global_symbol); + ExternFunc ret(global_symbol); + ret = WithAttrs(std::move(ret), Map{ + {String(kCSource), String(source)}, + {String(kCSourceFmt), String(kCSourceFmtCuda)}, + }); + return ret; +} + +/*! \brief Emit 2 calls to the library kernel and the rest of the function. */ +class SplitMutator : public ExprMutator { + public: + SplitMutator(const tvm::IRModule& mod, Array patterns, FCodegen fcodegen) + : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {} + static IRModule Transform(const IRModule& mod, Array patterns, FCodegen fcodegen) { + SplitMutator mutator(mod, patterns, fcodegen); + for (auto& kv : mod->functions) { + if (auto* func = kv.second.as()) { + Function new_func = Downcast(mutator(GetRef(func))); + mutator.builder_->UpdateFunction(kv.first, new_func); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + inline Array GetCallTIRArgs(Expr args) { + if (args.as()) { + return args.as()->fields; + } else { + return {args}; + } + } + + Expr VisitExpr_(const CallNode* op) final { + Call call = Downcast(ExprMutator::VisitExpr_(op)); + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed"); + if (!call->op.same_as(call_tir_op_)) return call; + // the first argument is the function to be called + const auto* gv_ptr = call->args[0].as(); + if (gv_ptr == nullptr) return call; + GlobalVar gv = GetRef(gv_ptr); + // retrieve the function from the module and split it + tir::PrimFunc func = Downcast(mod_->Lookup(gv)); + std::vector> arg_partition; + // split the function into two functions, one for the library kernel and one for the rest. 
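CodegenWithLibrary above splices the generated kernel source in by substituting the "{global_symbol}" placeholder via StringReplace. That helper is self-contained, so it can be exercised in isolation; the symbol name in the demo below is made up.

#include <cassert>
#include <string>

// Same replace-all helper, reproduced standalone: every occurrence of `search`
// in `subject` is replaced by `replace`, advancing past the inserted text so
// the loop terminates even when `replace` contains `search`.
void StringReplace(std::string* subject, const std::string& search,
                   const std::string& replace) {
  for (size_t pos = 0; (pos = subject->find(search, pos)) != std::string::npos;
       pos += replace.length()) {
    subject->replace(pos, search.length(), replace);
  }
}

int main() {
  // The emitted library source carries a "{global_symbol}" placeholder that is
  // patched to the kernel's exported symbol name before compilation.
  std::string source = "extern \"C\" int {global_symbol}(void** args);";
  StringReplace(&source, "{global_symbol}", "fused_gemm_bias");
  assert(source == "extern \"C\" int fused_gemm_bias(void** args);");
  return 0;
}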
+ std::pair> split_funcs = + tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); + if (!split_funcs.second.defined()) { + // no need to split, the function itself a library kernel + tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); + if (lib_func->IsInstance()) return GetRef(op); + // Update the function in the module with the library kernel + ICHECK(lib_func->IsInstance()); + builder_->UpdateFunction(gv, lib_func); + // emit the call to the library kernel + ObjectPtr new_call = make_object(*call.operator->()); + new_call->op = this->call_dps_packed_; + new_call->args = {lib_func, call->args[1]}; + return Call(new_call); + } + tir::PrimFunc func1 = tir::RenewDefs(split_funcs.first); + tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value()); + ICHECK(arg_partition.size() == 2); + // emit the first call to the library kernel + Array args1; + for (int p : arg_partition[0]) { + args1.push_back(GetCallTIRArgs(call->args[1])[p]); + } + // replace the function in the module with the library kernel + tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); + if (lib_func->IsInstance()) return GetRef(op); + ICHECK(lib_func->IsInstance()); + builder_->UpdateFunction(gv, lib_func); + tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); + DataType dtype = intermediate_buffer->dtype; + Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, + {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); + Var call_var1 = builder_->Emit(call1); + // emit the second call to the rest of the function + Array args2; + args2.push_back(call_var1); + for (int p : arg_partition[1]) { + args2.push_back(GetCallTIRArgs(call->args[1])[p]); + } + GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue"); + Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->sinfo_args); + builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol")); + return call2; + } + + const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed"); + tvm::IRModule mod_; + Array patterns_; + FCodegen fcodegen_; +}; + +namespace transform { +Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return SplitMutator::Transform(m, patterns, fcodegen); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"SplitCallTIRByPattern", // + /*required=*/{}); +} +TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc new file mode 100644 index 000000000000..952513db4c3d --- /dev/null +++ b/src/relax/transform/static_plan_block_memory.cc @@ -0,0 +1,789 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/static_plan_block_memory.cc + * \brief The static memory planning pass on BindingBlock level. + * \details + * The core data structure of the planning pass is StorageToken, which denotes + * reusable memory in this planning pass. + * + * The memory planning pass contains three stages: + * + * The first stage is initialization. A storage token object will be created + * for each builtin alloc_tensor as long as the allocated storage satisfies + * the requirements (which are described in the code). The reference counter + * (i.e., the times of reference) for each token is recorded. + * + * The second stage is allocation planning. We maintain a pool of available + * allocated storage, in the form of storage tokens. For the storage token of + * each builtin alloc_tensor, we check if there is appropriate available token + * in the pool under certain criterion. If there is, we reuse that storage + * for this alloc_tensor. Otherwise, we decide to allocate a storage for the + * alloc_tensor. + * + * The third stage is IR rewrite. Based on the decision made in the second + * stage, we insert memory alloc_storage, alloc_tensor, kill_tensor, and + * kill_storage accordingly. Specifically, we + * - insert alloc_storage before the site that each storage token is firstly + * used, + * - insert memory alloc_tensor for each builtin alloc_tensor, + * - insert kill_tensor after the site that a tensor created by alloc_tensor + * is last referenced, and + * - insert kill_storage at the end of each binding block, for all the storage + * tokens that are allocated inside the binding block, as the memory planning + * only works on block level. + */ +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A representation of a block of reusable memory required at runtime. + * \details Only the tensors whose memory can be "possibly reused" will have + * their storage token. In other words, we do not have storage token for tensor + * - that is a function parameter, + * - that is a function return value, + * - one of whose use site is a BindingBlock different from its allocation site, + * - that is used as a condition or branch return of a IfNode, + * - that is used as the body of a SeqExprNode, + * - that is used as arguments in a Call whose op is not a PrimFunc. + * + * In practice, we do create a storage token for such tensor at first. But at + * any time we find a tensor satisfying any of the conditions above, we erase + * its storage token. + */ +class StorageTokenNode : public Object { + public: + /*! \brief Reference counter. */ + int ref_counter{0}; + /*! \brief Number of bytes that this token requires. */ + int64_t bytes; + /*! \brief The dtype of this token. */ + DataType dtype; + /*! \brief The storage id, reserved for debug and demo use. */ + int storage_id{-1}; + + static constexpr const char* _type_key = "relax.transform.StorageToken"; + TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object); +}; + +/*! + * \brief Managed reference to StorageTokenNode. 
+ * \sa StorageTokenNode + */ +class StorageToken : public ObjectRef { + public: + explicit StorageToken(Array shape, DataType dtype) { + // Compute the tensor size from the shape. + int64_t size = 1; + for (const PrimExpr& dim_len : shape) { + const auto* int_len = dim_len.as(); + ICHECK_NOTNULL(int_len); + size *= int_len->value; + } + + ObjectPtr n = make_object(); + n->bytes = size * dtype.bytes() * dtype.lanes(); + n->dtype = dtype; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode); +}; + +// We use NestedMsg to store the tokens used by each Expr. +using Tokens = NestedMsg; + +/*! + * \brief Memory manager for flattened 1d memory (buffers) + * \note We can generalize this implementation to multi-dimensional memory + * following the same flow in the future. + */ +class TokenAllocator1D { + public: + /*! + * \brief Request a storage token from the available token pool for a + * given prototype, or report no appropriate available token in the pool. + * \param prototype The requesting prototype storage token. + * \return The request result token. Return NullOpt if there is no + * appropriate available token in the pool. + */ + Optional RequestReuse(StorageToken prototype) { + // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage + ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + // If the prototype has no reference at all, feel free to allocate new storage. + // The unused binding can be removed by cleaning passes. + if (prototype->ref_counter == 0) { + return NullOpt; + } + + // Step 1. Get the available pool of the token dtype. + std::multimap& pool = available_pool_[prototype->dtype]; + + // Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_) + int64_t size = prototype->bytes; + auto begin = pool.lower_bound(size / match_range_); + auto mid = pool.lower_bound(size); + auto end = pool.upper_bound(size * match_range_); + // Step 3. Search for memory block that equals or is larger than the requested size. + if (mid != end) { + StorageToken available_token = mid->second; + ICHECK_EQ(available_token->ref_counter, 0) + << "Available tokens are expected to have 0 reference."; + ICHECK_LE(size, available_token->bytes); + available_token->ref_counter = prototype->ref_counter; + pool.erase(mid); + return available_token; + } + // Step 4. Then search for memory block that is smaller than the requested size. + if (mid != begin) { + --mid; + StorageToken available_token = mid->second; + ICHECK_EQ(available_token->ref_counter, 0) + << "Available tokens are expected to have 0 reference."; + ICHECK_GE(size, available_token->bytes); + // Enlarge the token size. + available_token->bytes = size; + available_token->ref_counter = prototype->ref_counter; + pool.erase(mid); + return available_token; + } + // Return `NullOpt` indicating that no satisfiable storage token is found in the available pool. + return NullOpt; + } + + /*! + * \brief Allocate a storage token for the input prototype token. + * \param prototype The prototype token. + * \param storage_id The id of this token. 
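RequestReuse above keeps a per-dtype free pool keyed by byte size and only reuses blocks whose size falls inside the window [size / match_range_, size * match_range_): an equal-or-larger block is taken as is, a slightly smaller one is grown, and anything outside the window triggers a fresh allocation. A standalone sketch of that policy over plain byte counts (FreePool is an illustrative name, not TVM API):

#include <cassert>
#include <cstdint>
#include <map>
#include <optional>

// Illustrative sketch only: a free pool of storage blocks keyed by byte size.
class FreePool {
 public:
  void Release(int64_t bytes) { pool_.insert({bytes, bytes}); }

  std::optional<int64_t> RequestReuse(int64_t size) {
    auto begin = pool_.lower_bound(size / kMatchRange);
    auto mid = pool_.lower_bound(size);
    auto end = pool_.upper_bound(size * kMatchRange);
    if (mid != end) {  // an equal-or-larger block inside the window: reuse as is
      int64_t got = mid->second;
      pool_.erase(mid);
      return got;
    }
    if (mid != begin) {  // a slightly smaller block inside the window: enlarge it
      --mid;
      pool_.erase(mid);
      return size;
    }
    return std::nullopt;  // nothing close enough: caller allocates new storage
  }

 private:
  static constexpr int kMatchRange = 16;
  std::multimap<int64_t, int64_t> pool_;
};

int main() {
  FreePool pool;
  pool.Release(4096);
  assert(pool.RequestReuse(4000).value() == 4096);  // reuse the larger block
  pool.Release(1024);
  assert(pool.RequestReuse(1100).value() == 1100);  // grow the smaller block
  assert(!pool.RequestReuse(64).has_value());       // pool empty: allocate anew
  return 0;
}

The size window keeps wildly oversized blocks from being wasted on tiny tensors while still letting near-fits be shared, which is the same trade-off the pass makes with match_range_.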
+ */ + StorageToken Alloc(StorageToken prototype, int storage_id) { + // Sanity check: the prototype token is supposed not to be allocated with actual storage yet + ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + prototype->storage_id = storage_id; + full_pool_.push_back(prototype); + return prototype; + } + + /*! + * \brief Release the input token, putting it into the available pool. + * \param token The token to be released. + */ + void Release(StorageToken token) { + // Sanity check: the token has been allocated with actual storage, and should have 0 reference. + ICHECK_GE(token->storage_id, 0) + << "The token to be released is expected to be allocated before"; + ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference."; + available_pool_[token->dtype].insert({token->bytes, token}); + } + + private: + /*! \brief A constant scale representing the token search range. */ + const int match_range_{16}; + /*! \brief The pool of available storage tokens for each dtype. */ + std::unordered_map> available_pool_; + /*! \brief All the storage tokens that have been allocated with actual storage. */ + std::vector full_pool_; +}; + +/*! \brief Check if the input op is "relax.reshape". */ +bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); } + +/*! \brief The base class for the storage allocation visitor. */ +class StorageAllocatorBaseVisitor : public ExprVisitor { + protected: + using ExprVisitor::VisitExpr_; + + void VisitBindingBlock_(const BindingBlockNode* block) override { + // We maintain a block stack for token allocation-site and use-site check. + block_stack_.push_back(block); + ExprVisitor::VisitBindingBlock_(block); + ICHECK(!block_stack_.empty()); + ICHECK(block_stack_.back() == block); + block_stack_.pop_back(); + } + + void VisitBinding_(const VarBindingNode* binding) override { + ExprVisitor::VisitBinding_(binding); + // The binding var has the same tokens as the binding value. + SetTokens(binding->var.get(), token_map_[binding->value.get()]); + } + + void VisitExpr_(const TupleNode* tuple) final { + Array tokens; + tokens.reserve(tuple->fields.size()); + for (const Expr& field : tuple->fields) { + Tokens field_tokens = GetTokens(field); + tokens.push_back(field_tokens); + } + SetTokens(tuple, Tokens(tokens)); + } + + void VisitExpr_(const TupleGetItemNode* tuple_item) final { + Tokens tokens = GetTokens(tuple_item->tuple); + // If the tuple has no token, every of its field has no token as well. + if (tokens.IsNull()) { + token_map_[tuple_item] = Tokens(); + return; + } + ICHECK(tokens.IsNested()); + Array field_tokens = tokens.NestedArray(); + ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); + ICHECK_GE(tuple_item->index, 0); + SetTokens(tuple_item, field_tokens[tuple_item->index]); + } + + /******************** Utilities ********************/ + + Tokens GetTokens(const Expr& expr) { + this->VisitExpr(expr); + return token_map_[expr.get()]; + } + + virtual void SetTokens(const ExprNode* expr, Tokens tokens) { token_map_[expr] = tokens; } + + /*! \brief The mapping from each Expr to its corresponding storage tokens. */ + std::unordered_map token_map_; + /*! \brief The binding block stack. */ + std::vector block_stack_; +}; + +/*! + * \brief The visitor class for storage token initialization. + * \details It goes through the entire function to get the storage tokens + * used by each Expr. 
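The tuple handling in the base visitor above relies on a nested token message: a Tuple aggregates its fields' messages into an array, and TupleGetItem indexes straight back into that array, with a null message on the tuple meaning none of its fields carries a token. A minimal standalone analogue of that shape (Nested, Leaf, TupleOf and GetItem are illustrative names, not the NestedMsg API):

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch only: a nested message is either null, a single leaf
// token id, or an array of nested messages.
struct Nested {
  int leaf = -1;               // -1 means "not a leaf"
  std::vector<Nested> fields;  // non-empty means "nested"
  bool IsNull() const { return leaf < 0 && fields.empty(); }
};

Nested Leaf(int id) { return Nested{id, {}}; }
Nested TupleOf(std::vector<Nested> fields) { return Nested{-1, std::move(fields)}; }

Nested GetItem(const Nested& tuple, std::size_t index) {
  if (tuple.IsNull()) return Nested{};  // no token on the tuple: none on fields
  return tuple.fields.at(index);
}

int main() {
  // Two allocations packed into a tuple, then field 1 extracted again.
  Nested t = TupleOf({Leaf(0), Leaf(1)});
  assert(GetItem(t, 1).leaf == 1);
  assert(GetItem(Nested{}, 0).IsNull());
  return 0;
}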
After the initialization, we + * - know the tokens that each Expr is using, + * - know the number of references for each token, + * - rule out the builtin alloc_tensors to which the planning does not apply. + */ +class StorageAllocatorInit : public StorageAllocatorBaseVisitor { + public: + /*! + * \brief The entry of the initialization. + * \param mod The IRModule to be planned + * \return The mapping from each Expr to the token it uses. + */ + static std::unordered_map Initialize(const IRModule& mod) { + StorageAllocatorInit initializer(mod); + + for (auto it : mod->functions) { + const auto* func = it.second.as(); + if (func == nullptr) { + continue; + } + initializer(GetRef(func)); + } + return initializer.token_map_; + } + + private: + using ExprVisitor::VisitExpr_; + + explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {} + + void VisitExpr_(const FunctionNode* func) final { + // Recurse into the function to get its tokens. + Tokens body_tokens = GetTokens(func->body); + // Discard the tokens used by the function return value, as they are external referenced. + DiscardTokensIn(body_tokens); + } + + void VisitExpr_(const CallNode* call) final { + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + if (call->op == alloc_tensor_op) { + // Create a storage token for builtin alloc_tensor. + this->CreateToken(call); + return; + } else if (IsReshape(call->op)) { + // Reuse the input's token for builtin reshape. + SetTokens(call, GetTokens(call->args[0])); + return; + } + + // - Increase the reference counters of the arguments when the callee is + // a PrimFunc of the context module or an external function via 'call_packed'. + // It assumes external function calls via 'call_packed' do not retain memory + // from the arguments. + // - Otherwise, discard the tokens used by the arguments, as there might be + // potential external reference. + if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance()) { + ICHECK(!block_stack_.empty()); + for (const Expr& arg : call->args) { + Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back()); + ForEachLeaf(tokens, [](StorageToken token) { token->ref_counter += 1; }); + } + } else { + for (const Expr& arg : call->args) { + DiscardTokensIn(GetTokens(arg)); + } + } + } + + void VisitExpr_(const IfNode* if_node) final { + Tokens cond_tokens = GetTokens(if_node->cond); + Tokens then_tokens = GetTokens(if_node->true_branch); + Tokens else_tokens = GetTokens(if_node->false_branch); + // Discard the tokens used by the condition, then-body and else-body, + // as the planning works on block level. + DiscardTokensIn(cond_tokens); + DiscardTokensIn(then_tokens); + DiscardTokensIn(else_tokens); + } + + void VisitExpr_(const SeqExprNode* seq) final { + for (const BindingBlock& binding_block : seq->blocks) { + this->VisitBindingBlock(binding_block); + } + Tokens body_tokens = GetTokens(seq->body); + // Discard the tokens used by the body, as the planning works on block level. + DiscardTokensIn(body_tokens); + } + + /******************** Utilities ********************/ + + /*! + * \brief Check if the input op is GlobalVar corresponding to a PrimFunc inside the ctx module. + * \param op The op to be checked + * \return A boolean indicating if the input op corresponds to a PrimFunc. 
+ */ + bool IsPrimFuncGlobalVar(const Expr& op) { + const auto* global_var = op.as(); + if (global_var == nullptr) { + return false; + } + auto func_it = ctx_mod_->functions.find(GetRef(global_var)); + if (func_it == ctx_mod_->functions.end()) { + return false; + } + return (*func_it).second->IsInstance(); + } + + /*! + * \brief Create a storage token for the builtin alloc_tensor call. + * \param call The call to be processed. + * \return The created token. + */ + Tokens CreateToken(const CallNode* call) { + // Sanity checks about + // - the call return value is a Tensor; + // - the shape of the tensor is known, in the form of ShapeExpr; + // - the tensor has known dtype; + // - no storage token was created for this call before. + const auto* sinfo = call->struct_info_.as(); + const auto* shape = sinfo->shape.as(); + ICHECK_NOTNULL(sinfo); + ICHECK_NOTNULL(shape); + ICHECK(!sinfo->IsUnknownDtype()); + ICHECK(sinfo->dtype == Downcast(call->args[1])->value); + ICHECK(!token_map_.count(call)); + + // No support for symbolic shape at this moment. + for (const PrimExpr& dim_len : shape->values) { + const auto* int_len = dim_len.as(); + if (!int_len) { + token_map_[call] = Tokens(); + return Tokens(); + } + } + + // Create and set token. + StorageToken token(shape->values, sinfo->dtype); + + Tokens tokens(token); + SetTokens(call, tokens); + ICHECK(!block_stack_.empty()); + token2block_[token.get()] = block_stack_.back(); + return tokens; + } + + /*! + * \brief Override the token setter in the base visitor. + * For each token, we keep record of all Expr that are using that token. + * When we want to discard one token, we use the records to remove the token + * from the Expr that are using it. + */ + void SetTokens(const ExprNode* expr, Tokens tokens) final { + StorageAllocatorBaseVisitor::SetTokens(expr, tokens); + ForEachLeaf(tokens, [this, expr](StorageToken token) { + this->token2exprs_[token.get()].push_back(expr); + }); + } + + /*! + * \brief Token getter with allocation site check. + * We first get the tokens used by the input Expr, and check if the allocation + * site of each token is the input current block. + * Since the planning works on block level, if some token's allocation site + * is not the current block, we discard the token so that it will not be planned. + * \param expr The Expr whose tokens is to be got. + * \param cur_block The pointer to the current block. + * \return The tokens used by the input Expr. + */ + Tokens GetTokensWithAllocSiteCheck(const Expr& expr, const BindingBlockNode* cur_block) { + Tokens tokens = GetTokens(expr); + ForEachLeaf(tokens, [this, cur_block](StorageToken token) { + auto it = this->token2block_.find(token.get()); + ICHECK(it != this->token2block_.end()); + if (it->second != cur_block) { + this->DiscardToken(token); + } + }); + return token_map_[expr.get()]; + } + + /*! \brief Discard the input tokens. */ + void DiscardTokensIn(Tokens tokens) { + ForEachLeaf(tokens, [this](StorageToken token) { this->DiscardToken(token); }); + } + + /*! + * \brief Discard the input token. + * For each Expr that is using the input token, remove the token from the Expr's token set. + * \param token_to_discard The token to be discarded. + */ + void DiscardToken(StorageToken token_to_discard) { + const std::vector& exprs = token2exprs_[token_to_discard.get()]; + for (const ExprNode* expr : exprs) { + token_map_[expr] = MapNestedMsg(token_map_[expr], [token_to_discard](StorageToken token) { + return token.same_as(token_to_discard) ? 
Tokens() : Tokens(token); + }); + } + token2exprs_.erase(token_to_discard.get()); + token2block_.erase(token_to_discard.get()); + } + + /*! + * \brief The context IRModule, used for checking if a callee function is + * a PrimFunc inside the IRModule. + */ + const IRModule& ctx_mod_; + /*! \brief The mapping from each token to the binding block where it is created. */ + std::unordered_map token2block_; + /*! \brief The mapping from each token to the Exprs that are using this token. */ + std::unordered_map> token2exprs_; +}; + +/*! + * \brief The visitor class for storage token allocation planning. + * \details + * - For each builtin alloc_tensor whose token is not discarded in the + * initialization stage, we request a storage reuse or decide to allocate + * storage for this token, depending on if there is appropriate available + * token in the token pool we maintain. + * - For each VM builtin reshape, we reuse the input's tokens. + * + * After the allocation planning, we + * - know the token that each builtin alloc_tensor plans to use. Compared + * with the initialization, here the token is possibly a reuse of some + * previous token, rather than we having one token for each alloc_tensor. + * - know the last referenced site for each builtin alloc_tensor. This + * information is used for inserting kill_tensor in the rewrite stage. + * - know the tokens allocated in each binding block. This information + * is used for inserting kill_storage in the rewrite stage. + */ +class StorageAllocator : public StorageAllocatorBaseVisitor { + public: + explicit StorageAllocator(std::unordered_map token_map) { + this->token_map_ = std::move(token_map); + } + + void Allocate(const IRModule& mod) { + for (auto it : mod->functions) { + const auto* func = it.second.as(); + if (func == nullptr) { + continue; + } + this->VisitExpr_(func); + } + } + + /*! + * \brief The mapping from each `builtin.alloc_tensor` to its corresponding + * underlying storage token that it is using. + */ + std::unordered_map alloc_tensor2token; + /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */ + std::unordered_map> expr2killed_tensors; + /*! \brief The mapping from each binding block to the storage tokens that are create inside. */ + std::unordered_map> block2tokens; + + private: + using ExprVisitor::VisitBinding_; + using ExprVisitor::VisitExpr_; + + void VisitBindingBlock_(const BindingBlockNode* block) final { + StorageAllocatorBaseVisitor::VisitBindingBlock_(block); + // Sanity check: each token allocated inside the block should not be + // referenced by anyone at the end of the block. + for (const StorageTokenNode* token : block2tokens[block]) { + ICHECK_EQ(token->ref_counter, 0); + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + if (call->op == alloc_tensor_op) { + auto it = token_map_.find(call); + ICHECK(it != token_map_.end()); + + if (it->second.IsNull()) { + // IsNull being true means the token was discarded, and this alloc_tensor + // is not considered by the planning. + return; + } + ICHECK(it->second.IsLeaf()); + StorageToken new_token = this->RequestReuseOrAlloc(it->second.LeafValue()); + + // Record that this alloc_tensor is using the token. + alloc_tensor2token.insert({call, new_token}); + token2cur_tensor_[new_token.get()].push_back(binding->var); + SetTokens(call, Tokens(new_token)); + // Record that the token is allocated in the current block. 
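The planning in this visitor is driven by the reference counters set up during initialization: each later use of a token (in the non-alloc branch just below) decrements its counter, and the call that drops it to zero is remembered as the site where kill_tensor will be emitted. A standalone sketch of that bookkeeping, with Token and ReleaseTracker as illustrative names:

#include <cassert>
#include <map>
#include <string>

// Illustrative sketch only: a token starts with a reference count equal to its
// number of uses; every use decrements it, and the use that drops it to zero
// becomes the release site.
struct Token {
  int ref_counter = 0;
};

class ReleaseTracker {
 public:
  void AddUse(Token* t) { t->ref_counter += 1; }

  void Use(Token* t, const std::string& site) {
    assert(t->ref_counter > 0);
    t->ref_counter -= 1;
    if (t->ref_counter == 0) release_site_[t] = site;  // last use: kill here
  }

  std::string ReleaseSiteOf(Token* t) const { return release_site_.at(t); }

 private:
  std::map<Token*, std::string> release_site_;
};

int main() {
  Token a;
  ReleaseTracker tracker;
  tracker.AddUse(&a);         // used by "matmul"
  tracker.AddUse(&a);         // used by "add"
  tracker.Use(&a, "matmul");  // one reference still outstanding
  tracker.Use(&a, "add");     // count hits zero: kill_tensor goes after "add"
  assert(tracker.ReleaseSiteOf(&a) == "add");
  return 0;
}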
+ ICHECK(!block_stack_.empty()); + std::vector& block_tokens = block2tokens[block_stack_.back()]; + if (std::find(block_tokens.begin(), block_tokens.end(), new_token.get()) == + block_tokens.end()) { + block_tokens.push_back(new_token.get()); + } + return; + } else if (IsReshape(call->op)) { + Tokens tokens = GetTokens(call->args[0]); + ICHECK(!tokens.IsNested()); + if (tokens.IsLeaf()) { + // If the input is using a token, record that the reshape uses the token as well. + token2cur_tensor_[tokens.LeafValue().get()].push_back(binding->var); + SetTokens(call, tokens); + } else { + ICHECK(token_map_[call].IsNull()); + } + return; + } + + // Decrease the reference counter by one for each token that the arguments use. + // Check if a token can be released (i.e., has no reference) after decrease. + // And release it if so. + for (const Expr& arg : call->args) { + Tokens tokens = GetTokens(arg); + ForEachLeaf(tokens, [this, call](StorageToken token) { + ICHECK_GT(token->ref_counter, 0); + token->ref_counter -= 1; + this->CheckForRelease(token, call); + }); + } + } + + /*! \brief Request a storage reuse, or allocate storage if no appropriate storage is reusable. */ + StorageToken RequestReuseOrAlloc(StorageToken prototype) { + Optional token = allocator_.RequestReuse(prototype); + if (!token.defined()) { + return allocator_.Alloc(prototype, this->n_storage_++); + } else { + return token.value(); + } + } + + /*! + * \brief Check if a token has no reference and thus can be released. And release it if so. + * \param token The token to be checked. + * \param release_site The CallNode where the the input token is send for release. + * If the token is checked to release here, we keep record of the release site so that + * kill_tensor can be inserted here at the rewrite stage. + */ + void CheckForRelease(StorageToken token, const CallNode* release_site) { + // Sanity check: the token was allocated before and has non-negative reference. + ICHECK_GE(token->storage_id, 0); + ICHECK_GE(token->ref_counter, 0); + + if (token->ref_counter == 0) { + allocator_.Release(token); + auto it = token2cur_tensor_.find(token.get()); + ICHECK(it != token2cur_tensor_.end()); + // Record that the tensors that are using this token will be killed + // immediately after the release site. + std::vector& killed_tensors = expr2killed_tensors[release_site]; + killed_tensors.insert(killed_tensors.end(), it->second.begin(), it->second.end()); + token2cur_tensor_.erase(it); + } + } + + /*! \brief Number of allocated storages. */ + int n_storage_{0}; + /*! \brief The 1D memory allocator. */ + TokenAllocator1D allocator_; + /*! \brief The mapping from each token to the tensors that are currently using it. */ + std::unordered_map> token2cur_tensor_; +}; + +/*! + * \brief The rewriter class based on the token allocation planning. + * \details + * - For each builtin alloc_tensor that was planned, substitute it with a memory + * alloc_tensor. If no memory alloc_storage was created for it before, create one. + * - Insert memory kill_tensor at the release site of each tensor. + * - Insert memory kill_storage at the end of each binding block, for the tokens allocated in it. 
+ */ +class StorageAllocationRewriter : public ExprMutator { + public: + explicit StorageAllocationRewriter( + IRModule mod, std::unordered_map alloc_tensor2token, + std::unordered_map> expr2killed_tensors, + std::unordered_map> + block2tokens) + : ExprMutator(std::move(mod)), + alloc_tensor2token_(std::move(alloc_tensor2token)), + expr2killed_tensors_(std::move(expr2killed_tensors)), + block2tokens_(std::move(block2tokens)) {} + + IRModule Rewrite() { + const IRModule& mod = builder_->GetContextIRModule(); + for (const auto& [gv, base_func] : mod->functions) { + const auto* func_ = base_func.as(); + if (func_ == nullptr) { + continue; + } + token2storage_var_.clear(); + Function func = Downcast(this->VisitExpr_(func_)); + builder_->UpdateFunction(gv, func); + } + return builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + + // Insert `memory.kill_storage` for the storage tokens allocated inside this block. + for (const StorageTokenNode* token : block2tokens_[block]) { + auto it_token = token2storage_var_.find(token); + ICHECK(it_token != token2storage_var_.end()); + static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); + this->builder_->Emit(Call(mem_kill_storage, {it_token->second}), /*name_hint=*/"_"); + } + + BindingBlock new_block = builder_->EndBlock(); + return new_block; + } + + void VisitBinding_(const VarBindingNode* binding) final { + ExprMutator::VisitBinding_(binding); + + // Insert `memory.kill_tensor` for the tensors that need to be killed after this binding. + auto it = expr2killed_tensors_.find(binding->value.get()); + if (it != expr2killed_tensors_.end()) { + for (const Var& var : it->second) { + static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); + this->builder_->Emit(Call(mem_kill_tensor, {Downcast(this->VisitExpr(var))}), + /*name_hint=*/"_"); + } + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto it = alloc_tensor2token_.find(call); + if (it != alloc_tensor2token_.end()) { + const auto* sinfo = call->struct_info_.as(); + ICHECK_NOTNULL(sinfo); + ICHECK_NOTNULL(sinfo->shape.as()); + PrimValue runtime_device_index = Downcast(call->args[2]); + + // If the token is visited for the first time, create a storage variable using + // `memory.alloc_storage` for it. + StorageToken token = it->second; + Var storage_var{nullptr}; + auto it_token = token2storage_var_.find(token.get()); + if (it_token == token2storage_var_.end()) { + static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage"); + ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)}); + PrimValue virtual_device_index = runtime_device_index; + std::string storage_scope = "global"; + DataType dtype = token->dtype; + Call alloc_storage( + mem_alloc_storage, + {std::move(size), virtual_device_index, StringImm(storage_scope), DataTypeImm(dtype)}, + Attrs()); + storage_var = builder_->Emit(alloc_storage, "storage"); + token2storage_var_[token.get()] = storage_var; + } else { + storage_var = it_token->second; + } + + // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. 
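The rewriter below materializes exactly one memory.alloc_storage per planned token, memoized in token2storage_var_, while every original builtin.alloc_tensor becomes a memory.alloc_tensor view into that storage. The memoization itself is a plain get-or-create step, sketched standalone here (LazyStorageEmitter is an illustrative name):

#include <cassert>
#include <map>
#include <string>

// Illustrative sketch only: one storage object is emitted lazily per planned
// token, so alloc_tensors that share a token also share its storage.
class LazyStorageEmitter {
 public:
  // Returns the storage name for `token_id`, emitting it on first use.
  std::string GetOrEmitStorage(int token_id) {
    auto it = storage_.find(token_id);
    if (it != storage_.end()) return it->second;
    std::string name = "storage" + std::to_string(storage_.size());
    storage_[token_id] = name;  // a real pass would emit alloc_storage here
    return name;
  }

 private:
  std::map<int, std::string> storage_;
};

int main() {
  LazyStorageEmitter emitter;
  // Two alloc_tensors planned onto token 0 share one storage; token 1 gets its own.
  assert(emitter.GetOrEmitStorage(0) == "storage0");
  assert(emitter.GetOrEmitStorage(0) == "storage0");
  assert(emitter.GetOrEmitStorage(1) == "storage1");
  return 0;
}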
+ static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor"); + PrimValue offset = PrimValue::Int64(0); + DataType dtype = sinfo->dtype; + return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)}, + Attrs()); + } + + return ExprMutator::VisitExpr_(call); + } + + /*! + * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to + its corresponding underlying storage token that it is using. + */ + std::unordered_map alloc_tensor2token_; + /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */ + std::unordered_map> expr2killed_tensors_; + /*! \brief The mapping from each binding block to the storage tokens that are create inside. */ + std::unordered_map> block2tokens_; + /*! \brief The mapping from each token to its corresponding storage var in each function. */ + std::unordered_map token2storage_var_; +}; + +IRModule StaticPlanBlockMemory(IRModule mod) { + // Step 1. Initialize. + std::unordered_map token_map = StorageAllocatorInit::Initialize(mod); + // Step 2. Collect the memory allocation info. + StorageAllocator allocator(std::move(token_map)); + allocator.Allocate(mod); + // Step 3. Rewrite the function. + StorageAllocationRewriter rewriter(std::move(mod), // + std::move(allocator.alloc_tensor2token), + std::move(allocator.expr2killed_tensors), + std::move(allocator.block2tokens)); + return rewriter.Rewrite(); +} + +namespace transform { + +Pass StaticPlanBlockMemory() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::StaticPlanBlockMemory(std::move(m)); }; + return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc new file mode 100644 index 000000000000..e74ca7a4d86d --- /dev/null +++ b/src/relax/transform/to_mixed_precision.cc @@ -0,0 +1,538 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_mixed_precision.cc + * \brief Automatic mixed precision pass. 
+ */ + +#include +#include +#include + +#include + +#include "../op/nn/convolution.h" +#include "../op/tensor/datatype.h" +#include "../op/tensor/linear_algebra.h" +#include "infer_amp_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using runtime::String; + +int GetMixedPrecisionInfo(const CallNode* call_node) { + const OpNode* op_node = call_node->op.as(); + if (op_node == nullptr) { + return -1; + } + Op op = GetRef(op_node); + auto attr_map = Op::GetAttrMap("TMixedPrecisionPolicy"); + return attr_map.count(op) ? attr_map[op] : MixedPrecisionPolicyKind::kNever; +} + +/*! + * \brief Main logic to automatically cast fp32 input modules to fp16 for certain ops. + * + * Structurally speaking, a Relax function is composed of a series of VarBinding and + * MatchCast. And a specific class of VarBindings is the basic unit we want to rewrite. + * Formally, they are of the form: + * + * var = Call(Op, [args], attrs) + * + * where Op is a specific op we want to rewrite, and attrs is the attributes of the op. + * var and args are all exprs with type Tensor or Tuple of Tensors. They might + * be vars, constants, or Tuple of vars and constants. + * Depending on the properties of the op, we may have 3 different ways to rewrite it: + * + * 1. kAlways: Always cast the args to fp16 + * Currently, this is only used for gemm and conv ops (to favor the use of TensorCore) + * We always cast the input args to fp16, and the dtype of the accumulator is configured + * by the global output_dtype parameter (default to fp32). We cast the output to fp16. + * + * 2. kFollow: If any of the args if fp32, cast all args to fp32. Otherwise, use fp16. + * + * 3. kNever: Never cast the args to fp16. Always cast all args to fp32 (the original dtype). + * Some ops, such as softmax, have numerical issues when using fp16. We will always use fp32 + * to ensure the correctness. + * + * Note that in this case, we will actively cast the arg to fp16 only when it's used in kAlways. + * This is to ensure that we have numerical stability to the best effort. + * + * DTypeDecisionCollector: + * Note that if some tensor is only used in kAlways ops, we can store it in fp16 without worsening + * numerical stability or using more storage. We use a backward propagation pass to detect such + * tensors. We will store the information of each var in the only_fp16_map_. + * + * We reuse the NTtype struct to store the information of each var. There are 3 kinds of info: + * - Unknown (Float0): we never encounter a use of this tensor + * - Float16: we only encounter uses of this tensor in kAlways ops + * - Float32: we encounter some use of this tensor outside of kAlways ops + * The info value forms a semi-lattice, where Float8 is the top, Float16 is the middle, and + * Float32 is the bottom. The lower bound of two info values is the one with more bits. + * + * ToMixedPrecisionRewriter: + * We will then use a forward propagation pass to rewrite the program. Since we only keep one + * specific data type for each var, and we will cast the var to the required dtype locally when we + * encounter its use if needed. Note that we may cast the var to some certain dtype multiple + * times, but we decide not to store and reuse the casted copy due to the storage concern and to + * be more friendly to inlining and operator fusion. We will store the var to fp16 if it's only + * used in kAlways ops, otherwise we will store it as the natural output dtype of the op. 
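The per-tensor decision described above forms a small lattice keyed by float bit-width: "unknown" (no use seen yet, zero bits) can be refined to fp16, while any use outside a kAlways op pushes the tensor down to fp32, because merging two observations always keeps the wider, more conservative type. A standalone sketch of that merge, with Decision and Merge as illustrative names:

#include <cassert>

// Illustrative sketch only: the merge of two observations keeps the type with
// more bits, so a single non-kAlways use forces a tensor to stay fp32.
enum class Decision : int { kUnknown = 0, kFloat16 = 16, kFloat32 = 32 };

Decision Merge(Decision a, Decision b) {
  return static_cast<int>(a) >= static_cast<int>(b) ? a : b;
}

int main() {
  assert(Merge(Decision::kUnknown, Decision::kFloat16) == Decision::kFloat16);
  assert(Merge(Decision::kFloat16, Decision::kFloat32) == Decision::kFloat32);
  assert(Merge(Decision::kFloat16, Decision::kFloat16) == Decision::kFloat16);
  return 0;
}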
+ * + * The information of each op is registered in the + * Op::GetAttr("FInferMixedPrecision"). The registered function has signature: + * FInferMixedPrecision. We will call the registered function with the original call and the global + * output_dtype parameter. The registered function will return the policy of the op, whether the op + * can adjust the dtype of the accumulator, and the new call node with output_dtype set to the + * global output_dtype parameter. + * + * Key design: wrap_param op + * We need to use fp16 parameters (which appear as constants in the program), but the type + * inference will fail if some parameters are fp16 and some are fp32 in the original module. To + * solve this, we introduce a new op wrap_param, which will wrap the original parameter and cast + * it to fp32 var. + * + * When we encounter the var afterwards, we will directly replace it with the parameter. This + * information is tracked by the const_map_. + */ +class DTypeDecisionCollector : public ExprVisitor { + public: + explicit DTypeDecisionCollector(DataType output_dtype) : output_dtype_(output_dtype) {} + + static VarDTypeMap Collect(Function func, DataType output_dtype) { + DTypeDecisionCollector collector(output_dtype); + collector.VisitExpr(func); + return std::move(collector.only_fp16_map_); + } + + private: + NType GetDType(const Var& var) { + auto it = only_fp16_map_.find(var); + if (it == only_fp16_map_.end()) { + // we never encounter this var before + NType unknown = NTypeFrom(var, unknown_); + only_fp16_map_[var] = unknown; + return unknown; + } + return it->second; + } + + // merge the message for a var + void UpdateVarDTypeMap(const Var& var, const NType& dtype) { + auto it = only_fp16_map_.find(var); + if (it == only_fp16_map_.end()) { + only_fp16_map_[var] = dtype; + } else { + only_fp16_map_[var] = NTypeMerge(it->second, dtype); + } + } + + // merge the message for all vars in the expr list + void RequireArgsToType(Array args, Array to) { + ICHECK(args.size() == to.size()) << "Invalid target dtypes"; + for (size_t i = 0; i < args.size(); ++i) { + auto fvisitleaf = [&](const Expr& expr, NType to) { + if (const auto* var = expr.as()) { + UpdateVarDTypeMap(GetRef(var), to); + } else if (expr->IsInstance()) { + // Constant can be casted anyway, so we don't need to do anything here + return; + } else { + LOG(FATAL) << "Unsupported argument type: " << expr->GetTypeKey(); + } + }; + DecomposeNestedMsg(args[i], to[i], fvisitleaf); + } + } + + // merge the message for all vars in the expr list + void RequireArgsToType(Array args, DataType to) { + std::vector arg_arr; + std::vector to_arr; + for (const Expr& arg : args) { + if (IsNestedTensor(arg)) { + // only require the nested tensor args + arg_arr.push_back(arg); + to_arr.push_back(NTypeFrom(arg, to)); + } + } + RequireArgsToType(std::move(arg_arr), std::move(to_arr)); + } + + void VisitVars_(const VarNode* op) { + Var var = GetRef(op); + if (IsNestedTensor(var)) { + // require the var to be fp32 (its original dtype) + UpdateVarDTypeMap(var, NTypeFrom(var, fp32_)); + return; + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const VarNode* op) final { VisitVars_(op); } + + void VisitExpr_(const DataflowVarNode* op) final { VisitVars_(op); } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + auto policy = GetMixedPrecisionInfo(call_node); + if (policy == -1) { + ExprVisitor::VisitBinding_(binding, call_node); + return; + } + if (policy == kAlways) { + // require inputs to be fp16 + 
RequireArgsToType(call_node->args, fp16_); + } else if (policy == kFollow || policy == kNever) { + // require inputs to be fp32 (the original dtype) + RequireArgsToType(call_node->args, fp32_); + } else { + LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy; + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple_node) final { + // require input fields to be the type of the lhs field respectively + NType lhs_type = GetDType(binding->var); + RequireArgsToType(tuple_node->fields, lhs_type.NestedArray()); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + // require the i-th field rhs tuple to be the type of the lhs + NType lhs_type = GetDType(binding->var); + std::vector require_rhs; + const TupleStructInfoNode* sinfo = + tuple_get_item_node->tuple->struct_info_.as(); + ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; + for (size_t i = 0; i < sinfo->fields.size(); ++i) { + if (i == static_cast(tuple_get_item_node->index)) { + require_rhs.push_back(lhs_type); + } else { + require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_)); + } + } + RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)}); + } + + // override the following methods to visit in backward order + void VisitExpr_(const SeqExprNode* op) final { + this->VisitSpan(op->span); + this->VisitExpr(op->body); + for (auto it = op->blocks.rbegin(); it != op->blocks.rend(); it++) { + this->VisitBindingBlock(*it); + } + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } + } + + void VisitBindingBlock_(const BindingBlockNode* block) { return; } + + void VisitBindingBlock_(const DataflowBlockNode* block) { + for (auto it = block->bindings.rbegin(); it != block->bindings.rend(); it++) { + this->VisitBinding(*it); + } + } + + void VisitExpr_(const IfNode* op) final { + this->VisitSpan(op->span); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); + this->VisitExpr(op->cond); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } + } + + DataType unknown_ = DataType(DataType::TypeCode::kFloat, 0, 1); + DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); + DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); + DataType output_dtype_; + VarDTypeMap only_fp16_map_; +}; + +class ToMixedPrecisionRewriter : public ExprMutator { + public: + explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype) + : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {} + + private: + Var GetRemapped(const Var& var) { + auto it = var_remap_.find(var->vid); + return it == var_remap_.end() ? 
var : it->second; + } + + Array RemapArgs(const Array& args) { + Array new_args; + for (const auto& arg : args) { + new_args.push_back(VarReplacer::Replace(arg, var_remap_)); + } + return new_args; + } + + // Util function to rewrite the expr to the given dtype + // rewrite each leaf tensor to the given dtype if necessary + // Note that this function only accepts expr with nested tensor type + Expr RewriteExpr(const Expr& expr, const NType& to) { + auto fvisitleaf = [&](const Expr& expr, std::array to) -> Expr { + const auto* tensor = GetStructInfoAs(expr); + ICHECK(tensor != nullptr) << "Only support rewriting tensor expr"; + // We only rewrite the expr if the dtype is not the same as the given dtype + if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr; + // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not + // supported to be rewritten + if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr; + return astype(expr, DataType(String2DLDataType(to[0].LeafValue()))); + }; + return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); + } + + Array RewriteArgs(const Array& args, DataType to) { + Array new_args; + for (const Expr& arg : args) { + if (IsNestedTensor(arg)) { + new_args.push_back(RewriteExpr(arg, NTypeFrom(arg, to))); + } else { + new_args.push_back(arg); + } + } + return new_args; + } + + // Util function to check if any of the tensors in the args is fp32 + bool AnyArgIsFP32(const NType& cur_type) { + bool result = false; + auto fvisitleaf = [&](const String& dtype) { + if (dtype == "float32") { + result = true; + } + }; + ForEachLeaf(cur_type, fvisitleaf); + return result; + } + + bool AnyArgIsFP32(const Array& args) { + for (const Expr& arg : args) { + if (IsNestedTensor(arg)) { + if (AnyArgIsFP32(NTypeFrom(arg))) return true; + } + } + return false; + } + + void CastIfFp16Only(const Var& var) { + ICHECK(builder_->CurrentBlockIsDataFlow()); + // Get the current remapped var + Var cur_var = GetRemapped(var); + // Store the tensors that are fp16 only to fp16 + auto it = only_fp16_map_->find(var); + if (it == only_fp16_map_->end()) return; + // Get the to dtype, cast to fp16 if the var is fp16 only, otherwise do nothing + auto fcombine = [](const String& from, const String& required) -> String { + return required == "float16" ? 
required : from; + }; + NType from = NTypeFrom(cur_var); + NType to = CombineNestedMsg(from, it->second, fcombine); + Expr rewrite = RewriteExpr(cur_var, to); + // If cur_var is not rewritten, we don't need to emit a new var + if (!rewrite.same_as(cur_var)) { + // Emit a new var, and update the var remap + var_remap_[var->vid] = builder_->Emit(rewrite); + } + } + + Expr VisitVar_(const Var& var) { + // We rewrite the remapped var to the original dtype + auto it = var_remap_.find(var->vid); + if (it != var_remap_.end()) { + return RewriteExpr(it->second, NTypeFrom(var)); + } + return var; + } + + Expr VisitExpr_(const VarNode* op) final { + if (!builder_->CurrentBlockIsDataFlow()) { + return ExprMutator::VisitExpr_(op); + } + return VisitVar_(GetRef(op)); + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + if (!builder_->CurrentBlockIsDataFlow()) { + return ExprMutator::VisitExpr_(op); + } + return VisitVar_(GetRef(op)); + } + + void VisitBinding(const Binding& binding) { + ExprMutator::VisitBinding(binding); + if (!builder_->CurrentBlockIsDataFlow()) return; + CastIfFp16Only(binding->var); + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + if (!builder_->CurrentBlockIsDataFlow()) { + ExprMutator::VisitBinding_(binding, call_node); + return; + } + auto policy = GetMixedPrecisionInfo(call_node); + if (policy == -1) { + // not an op call + ExprMutator::VisitBinding_(binding, call_node); + return; + } + // var = Call(op) + const auto* op_node = call_node->op.as(); + ICHECK(op_node != nullptr); + Op op = GetRef(op_node); + if (wrap_param_op.same_as(op)) { + // wrap_param + ReEmitBinding(binding, call_node->args[0]); + return; + } + DataType to; + ObjectPtr new_call = make_object(*call_node); + // We first to remap the args to the current vars according to the var_remap_ + new_call->args = std::move(RemapArgs(call_node->args)); + // Then we rewrite the args according to the policy + if (policy == kAlways) { + to = fp16_; + auto attr_map = Op::GetAttrMap("FInferMixedPrecision"); + ICHECK(attr_map.count(op)); + auto f = attr_map[op]; + new_call = make_object(*(f(Call(new_call), output_dtype_).get())); + } else if (policy == kFollow) { + to = AnyArgIsFP32(new_call->args) ? 
fp32_ : fp16_; + } else if (policy == kNever) { + to = fp32_; + } else { + LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy; + } + new_call->args = std::move(RewriteArgs(new_call->args, to)); + new_call->struct_info_ = NullOpt; + Expr new_value = builder_->Normalize(Call(new_call)); + if (policy == kAlways && binding->var->IsInstance()) { + // kAlways: store the tensors to fp16 + // But global vars will be stored to the original dtype anyway (see below) + new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_)); + } + if (!binding->var->IsInstance()) { + // Global var: store the tensors to the original dtype + NType to = NTypeFrom(binding->var); + new_value = RewriteExpr(new_value, to); + } + ReEmitBinding(binding, builder_->Normalize(new_value)); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple_node) final { + if (!builder_->CurrentBlockIsDataFlow()) { + ExprMutator::VisitBinding_(binding, tuple_node); + return; + } + ObjectPtr new_tuple = make_object(*tuple_node); + new_tuple->fields = std::move(RemapArgs(tuple_node->fields)); + new_tuple->struct_info_ = NullOpt; + Expr new_value = builder_->Normalize(Tuple(new_tuple)); + if (!binding->var->IsInstance()) { + // Global var: store the tensors to the original dtype + NType to = NTypeFrom(binding->var); + new_value = RewriteExpr(new_value, to); + } + ReEmitBinding(binding, builder_->Normalize(new_value)); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (!builder_->CurrentBlockIsDataFlow()) { + // We don't need to rewrite the tuple_get_item in dataflow block + ExprMutator::VisitBinding_(binding, tuple_get_item_node); + return; + } + ObjectPtr new_tuple_get_item = + make_object(*tuple_get_item_node); + new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; + new_tuple_get_item->struct_info_ = NullOpt; + Expr new_value = TupleGetItem(new_tuple_get_item); + if (!binding->var->IsInstance()) { + // Global var: store the tensors to the original dtype + NType to = NTypeFrom(binding->var); + new_value = RewriteExpr(new_value, to); + } + ReEmitBinding(binding, builder_->Normalize(new_value)); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + // prepare local versions of params here, if they are fp16 expected only + for (auto param : params_) { + CastIfFp16Only(param); + } + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + for (auto param : params_) { + // remove the local version of params + auto it = var_remap_.find(param->vid); + if (it != var_remap_.end()) { + var_remap_.erase(it); + } + } + return builder_->EndBlock(); + } + + Expr VisitExpr_(const FunctionNode* op) final { + params_ = op->params; + return ExprMutator::VisitExpr_(op); + } + + const VarDTypeMap* only_fp16_map_; + + DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); + DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); + DataType output_dtype_; + Array params_; + + const Op& wrap_param_op = Op::Get("relax.wrap_param"); +}; + +Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) { + VarDTypeMap only_fp16_map = std::move(DTypeDecisionCollector::Collect(f, out_dtype)); + ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype); + return mutator(f); +} + +namespace transform { + +Pass ToMixedPrecision(const DataType& out_dtype) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return 
Downcast(ToMixedPrecision(f, out_dtype)); + }; + return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc new file mode 100644 index 000000000000..db2e9d7ee5e7 --- /dev/null +++ b/src/relax/transform/to_non_dataflow.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_non_dataflow.cc + * \brief Transform all dataflow structure to non-dataflow version. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ToNonDFMutator : public ExprMutator { + public: + Var VisitVarDef(const Var& var) final { + if (var.as()) { + Var new_var = Var(var->vid, GetStructInfo(var), var->span); + this->var_remap_[var->vid] = new_var; + return new_var; + } + return var; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } +}; + +Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); } + +namespace transform { + +Pass ToNonDataflow() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(ToNonDataflow(f)); }; + return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc new file mode 100644 index 000000000000..0d239e5fbf81 --- /dev/null +++ b/src/relax/transform/tuning_api/database.cc @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/database.cc + * \brief Database of tuning APIs. + */ +#include + +#include +#include +#include + +#include "../../../meta_schedule/utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); + +} // namespace meta_schedule +} // namespace tvm + +namespace tvm { +namespace relax { + +TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON(bool include_irmod) const { + return Array{trace->AsJSON(include_irmod), // + run_secs}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { + Trace trace{nullptr}; + Optional> run_secs{nullptr}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + trace = Trace::FromJSON(json_trace); + } + + // Load json[1] => run_secs + if (json_array->at(1).defined()) { + run_secs = meta_schedule::AsFloatArray(json_array->at(1)); + } + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs); +} + +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + +// TODO(tvm-team): Currently, we strictly treat each target separately. +// Since not every option in the target matters, this might be the overkill. +// Revisit this when we have better approach with target equality check. +inline std::string get_database_key(int workload_idx, Target target) { + return std::to_string(workload_idx) + "/" + target->str(); +} + +/*! \brief The default database implementation, which mimics two database tables with two files. + */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief The path to the measurement table */ + String path_measurement_record; + /*! \brief All the workloads in the database */ + std::unordered_map + workloads2idx_; + /*! \brief All the tuning records in the database */ + std::unordered_map> + tuning_records_; + + /*! 
\brief Measurement logs in the database */ + std::unordered_map> measurement_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + v->Visit("path_measurement_record", &path_measurement_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + // `measurement_records_` is not visited + } + + static constexpr const char* _type_key = "relax.tuning_api.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + bool HasWorkload(const IRModule& mod) { + return workloads2idx_.find(meta_schedule::Workload(mod, tvm::StructuralHash()(mod))) != + workloads2idx_.end(); + } + + bool HasMeasurementRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return measurement_records_.count(key) > 0; + } + + bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return tuning_records_.count(key) > 0; + } + + meta_schedule::Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(meta_schedule::Workload(mod, tvm::StructuralHash()(mod)), -1); + meta_schedule::Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + meta_schedule::JSONFileAppendLine(this->path_workload, + meta_schedule::JSONDumps(workload->AsJSON())); + } + return it->first; + } + + void CommitMeasurementRecord(const meta_schedule::Workload& workload, const Target& target, + const Array& run_secs) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + + if (measurement_records_[key].size() == 0) { + measurement_records_[key] = run_secs; + meta_schedule::JSONFileAppendLine(this->path_measurement_record, + meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), + run_secs // + })); + } else { + LOG(WARNING) << "Measurement record for " << key + << " already exists. Use the existing one instead."; + } + } + + void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) { + int workload_idx = this->workloads2idx_.at(workload); + // There may exist multiple tuning records (with different traces) for a single key pair. 
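// Illustrative, standalone sketch (not taken from this patch): the key scheme and
// ordering used by this database. Keys combine the workload index with the target
// string produced by Target::str(), and tuning records compare by mean run time,
// with unmeasured records pushed to the back via the kMaxMeanTime sentinel above.
#include <cassert>
#include <string>
#include <vector>

std::string DatabaseKey(int workload_idx, const std::string& target_str) {
  return std::to_string(workload_idx) + "/" + target_str;
}

double MeanRunSecs(const std::vector<double>& run_secs) {
  if (run_secs.empty()) return 1e10;  // mirrors kMaxMeanTime: unmeasured records sort last
  double sum = 0.0;
  for (double s : run_secs) sum += s;
  return sum / run_secs.size();
}

int main() {
  assert(DatabaseKey(0, "llvm -keys=cpu") == "0/llvm -keys=cpu");
  assert(MeanRunSecs({1.0, 3.0}) == 2.0);
  assert(MeanRunSecs({}) == 1e10);
  return 0;
}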
+ std::string key = get_database_key(workload_idx, target); + this->tuning_records_[key].insert(record); + + meta_schedule::JSONFileAppendLine( + this->path_tuning_record, meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), record->AsJSON()})); + } + + Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + int idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(idx, target); + for (const TuningRecord& record : this->tuning_records_[key]) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + + return results; + } + + Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) { + int workload_idx = this->workloads2idx_.at(workload); + return this->measurement_records_[get_database_key(workload_idx, target)]; + } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing) { + int num_threads = std::thread::hardware_concurrency(); + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); + + std::vector workload_idxs; + std::vector targets; + std::vector records; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + records.resize(size, TuningRecord{nullptr}); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + records[task_id] = TuningRecord::FromJSON(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + + for (int i = 0; i < size; i++) { + std::string key = get_database_key(workload_idxs[i], targets[i]); + n->tuning_records_[key].insert(records[i]); + } + } + + // Load `n->measuremet_log` from `path_measurement_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); + std::vector workload_idxs; + std::vector targets; + std::vector> measurements; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + measurements.resize(size, Array({})); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = 
Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + measurements[task_id] = meta_schedule::AsFloatArray(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + for (int i = 0; i < size; i++) { + n->measurement_records_[get_database_key(workload_idxs[i], targets[i])] = measurements[i]; + } + } + + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + n->path_measurement_record = path_measurement_record; + return Database(n); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") + .set_body_typed([](Trace trace, Optional> run_secs) { + return TuningRecord(trace, run_secs); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); + +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") + .set_body_method(&DatabaseNode::HasMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") + .set_body_method(&DatabaseNode::HasTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") + .set_body_method(&DatabaseNode::CommitMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") + .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") + .set_body_method(&DatabaseNode::GetMeasurementRecord); + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc new file mode 100644 index 000000000000..ef4a3d41bdf0 --- /dev/null +++ b/src/relax/transform/tuning_api/primitives.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/primitives.cc + * \brief Primitives of tuning APIs. 
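// Illustrative sketch (not taken from this patch): typical use of the JSONDatabase
// implemented above from C++. `mod`, `target`, `trace`, and `run_secs` are
// placeholders produced elsewhere; the file paths are arbitrary, and the database
// methods are assumed to be reachable through Database's usual operator->.
void SaveTuningResult(tvm::IRModule mod, tvm::Target target, tvm::relax::Trace trace,
                      tvm::Array<tvm::FloatImm> run_secs) {
  using namespace tvm::relax;
  Database db = Database::JSONDatabase("workloads.json", "tuning_records.json",
                                       "measurements.json", /*allow_missing=*/true);
  tvm::meta_schedule::Workload workload = db->CommitWorkload(mod);
  db->CommitTuningRecord(workload, target, TuningRecord(trace, run_secs));
  // Later: fetch the best-performing candidates for this workload/target pair.
  tvm::Array<TuningRecord> best = db->GetTopK(workload, target, /*top_k=*/3);
  (void)best;
}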
+ */ + +#include + +#include "../../../meta_schedule/utils.h" +namespace tvm { +namespace relax { + +Choice::Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + ObjectPtr n = make_object(); + n->transform_func_key = std::move(transform_func_key); + n->transform_func_args = std::move(transform_func_args); + n->constr_func_key = std::move(constr_func_key); + n->constr_func_args = std::move(constr_func_args); + data_ = std::move(n); +} + +// TODO(sunggg): Currently, it only supports an array of primitive data types. +ObjectRef ChoiceNode::AsJSON() const { + Array json_transfrom_args, json_constr_args; + for (ObjectRef arg : this->transform_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_transfrom_args.push_back(String(b64_arg)); + } + for (ObjectRef arg : this->constr_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_constr_args.push_back(String(b64_arg)); + } + return Array{ + this->transform_func_key, + json_transfrom_args, + this->constr_func_key, + json_constr_args, + }; +} + +Choice Choice::FromJSON(const ObjectRef& json) { + // Parse `json` into `choice` + String transform_func_key, constr_func_key; + Array transform_func_args, constr_func_args; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + transform_func_key = GetRef(arr0); + { + transform_func_args.reserve(arr1->size()); + for (const ObjectRef& elem : *arr1) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + transform_func_args.push_back(arg); + } + } + constr_func_key = GetRef(arr2); + { + constr_func_args.reserve(arr3->size()); + for (const ObjectRef& elem : *arr3) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + constr_func_args.push_back(arg); + } + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); +} + +Knob::Knob(String name, Map choices) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->choices = std::move(choices); + data_ = std::move(n); +} + +ObjectRef KnobNode::AsJSON() const { + Map json_choices; + for (auto const& x : choices) { + json_choices.Set(x.first, x.second->AsJSON()); + } + return Array{ + /* 0: name */ std::move(name), + /* 1: choices */ std::move(json_choices), + }; +} + +Knob Knob::FromJSON(const ObjectRef& json) { + // Parse `json` into `name` and `choices` + String name; + Map choices; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + name = GetRef(arr0); + for (auto const& x : GetRef>(arr1)) { + String decision = x.first; + Choice choice = Choice::FromJSON(x.second); + choices.Set(decision, choice); + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: 
" + << json; + throw; + } + return Knob(name, choices); +} + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { + ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; + // Deep-copy IRModule + auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule"); + ICHECK(func_deepcopy); + IRModule out_mod = (*func_deepcopy)(in_mod); + // Apply the decision history if provided + int size = knobs.size(); + for (int i = 0; i < size; i++) { + out_mod = knobs[i]->Apply(out_mod, decisions[i]); + } + + ObjectPtr n = make_object(); + n->in_mod = std::move(in_mod); + n->out_mod = std::move(out_mod); + n->knobs = std::move(knobs); + n->decisions = std::move(decisions); + n->size = std::move(size); + data_ = std::move(n); +} + +ObjectRef TraceNode::AsJSON(bool include_in_mod) const { + ICHECK(this->Verify()) << "Trace should be valid"; + + Array json_knobs; + Array json_decisions; + + int size = this->size; + json_knobs.reserve(size); + json_decisions.reserve(size); + + for (int i = 0; i < size; i++) { + const Knob& knob = this->knobs[i]; + const String& decision = this->decisions[i]; + + json_knobs.push_back(knob->AsJSON()); + json_decisions.push_back(decision); + } + if (include_in_mod) { + std::string json_mod = tvm::SaveJSON(this->in_mod); + std::string b64_mod = meta_schedule::Base64Encode(json_mod); + return Array{json_knobs, json_decisions, String(b64_mod)}; + } else { + return Array{json_knobs, json_decisions}; + } +} + +Trace Trace::FromJSON(const ObjectRef& json) { + // Parse `json` into `trace` + IRModule in_mod; + Array knobs; + Array decisions; + try { + const ArrayNode* arr = json.as(); + // A trace will have 2 or 3 entries depending on `include_irmod` parameter. 
+ ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); + + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + + for (const ObjectRef& elem : *arr0) { + knobs.push_back(Knob::FromJSON(elem)); + } + + for (const ObjectRef& elem : *arr1) { + decisions.push_back(Downcast(elem)); + } + + // When `include_irmod = true` + if (arr->size() == 3) { + const auto* arr2 = arr->at(2).as(); + String b64_mod = GetRef(arr2); + ICHECK(arr2); + std::string json_mod = meta_schedule::Base64Decode(b64_mod); + in_mod = Downcast(LoadJSON(json_mod)); + } + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Malformed Trace format - " << json; + throw; + } + return Trace(in_mod, knobs, decisions); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(ChoiceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") + .set_body_typed([](String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") + .set_body_method(&ChoiceNode::GetTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") + .set_body_method(&ChoiceNode::GetConstrFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") + .set_body_method(&ChoiceNode::ApplyTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); + +TVM_REGISTER_NODE_TYPE(KnobNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") + .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") + .set_body_method(&KnobNode::IsValidDecision); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") + .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { + return Trace(in_mod, knobs, decisions); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod") + .set_body_method(&TraceNode::SetOutMod); + +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc new file mode 100644 index 000000000000..9a19115f6274 --- /dev/null +++ b/src/relax/transform/utils.cc @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
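// Illustrative sketch (not taken from this patch): composing the tuning-API
// primitives registered above. The packed-function keys used here
// ("testing.apply_my_pass", "testing.identity_transform", "testing.always_true_constr")
// are hypothetical placeholders for globals registered elsewhere -- Knob::Apply looks
// them up in the registry, so the Trace below only builds if such functions exist.
// `mod` is a placeholder IRModule.
tvm::relax::Trace MakeTrace(tvm::IRModule mod) {
  using namespace tvm::relax;
  // One knob with two choices: apply a hypothetical transformation, or leave the
  // module unchanged via a hypothetical no-op transform key.
  Choice apply("testing.apply_my_pass", /*transform_func_args=*/{},
               "testing.always_true_constr", /*constr_func_args=*/{});
  Choice keep("testing.identity_transform", {}, "testing.always_true_constr", {});
  Knob knob("apply_my_pass", {{"enabled", apply}, {"disabled", keep}});
  // The Trace constructor re-applies the decision history to a deep copy of `mod`.
  return Trace(mod, /*knobs=*/{knob}, /*decisions=*/{"enabled"});
}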
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "utils.h" + +namespace tvm { +namespace relax { + +bool IsNestedTensor(const StructInfo& sinfo) { + return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& sinfo) { return true; }); +} + +bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); } + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h new file mode 100644 index 000000000000..d51fe5310146 --- /dev/null +++ b/src/relax/transform/utils.h @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/utils.h + * \brief Additional utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_TRANSFORM_UTILS_H_ +#define TVM_RELAX_TRANSFORM_UTILS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/array.h" +#include "../op/nn/convolution.h" +#include "../op/nn/nn.h" +#include "../op/nn/pooling.h" +#include "../op/tensor/binary.h" +#include "../op/tensor/create.h" +#include "../op/tensor/datatype.h" +#include "../op/tensor/index.h" +#include "../op/tensor/linear_algebra.h" +#include "../op/tensor/manipulate.h" +#include "../op/tensor/search.h" +#include "../op/tensor/set.h" +#include "../op/tensor/statistical.h" +#include "../op/tensor/ternary.h" +#include "../op/tensor/unary.h" + +namespace tvm { +namespace relax { + +/*! + * \brief A simple wrapper around ExprFunctor for a single argument case. + * The result of visit is memoized. + */ +template +class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor { + using BaseFunctor = ::tvm::relax::ExprFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~MemoizedExprTranslator() {} + + /*! + * \brief The memoized call. + * \param n The expression node. 
+ * \return The result of the call + */ + virtual OutputType VisitExpr(const Expr& n) { + ICHECK(n.defined()); + auto it = memo_.find(n); + if (it != memo_.end()) { + return it->second; + } + auto res = BaseFunctor::VisitExpr(n); + memo_[n] = res; + return res; + } + + virtual OutputType VisitExpr_(const VarNode* vn) { + ICHECK(memo_.count(GetRef(vn))); + return memo_[GetRef(vn)]; + } + + virtual OutputType VisitBinding_(const VarBindingNode* binding) { + ICHECK_EQ(memo_.count(binding->var), 0); + auto v = VisitExpr(binding->value); + memo_[binding->var] = v; + return v; + } + + protected: + /*! \brief Internal map used for memoization. */ + std::unordered_map memo_; +}; + +/*! + * \brief Dead code elimination + * Currently it removes: + * 1. Unused local VarBindings in a DataflowBlock. + * The used var set is set to empty at the beginning of each DataflowBlock. + * We reverse scan the DataflowBlock, if a VarBinding + * - bindings to a dataflowvar, or + * - is used in the used var set + * We keep it and add its var to the used var set. Otherwise, we remove it. + * 2. Unused Relax functions in the module. + * We detect the call chain from the entry function, and remove all unused functions. + * \param mod The target module + * \param entry_functions list of entry functions + * \return The updated module. + */ +TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, Array entry_funcs); + +/*! + * \brief Get the external symbol of the Relax function name. + * + * \param func The provided function. + * \return An external symbol. + */ +inline std::string GetExtSymbol(const Function& func) { + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; + return std::string(name_node.value()); +} + +/*! + * \brief Fuse ops or functions according to the given partition, and grouped them into a new + * function. + * + * \param mod The input module. + * \param partition A mapping from a subexpression to the containing group. + * \param lift_constants Whether or not to lift bound constants to parameters of the + * grouped function. + * \return A new module containing grouped functions. + */ +IRModule MakeGroupedFunctions( + IRModule mod, + const std::unordered_map& partition, + bool lift_constants = true); + +/*! + * \brief Check if the given StructInfo is a nested tensor StructInfo satisfying the given + * condition f_condition. + * \param sinfo The StructInfo to be checked. + * \param f_condition The condition function for each leaf StructInfo with signature + * `bool f_condition(TensorStructInfo)`. + * \tparam FType The condition function type. + * \return true if the given StructInfo is a nested tensor satisfying the given f_condition. + */ +template +bool IsNestedTensorConditioned(const StructInfo& sinfo, FType f_condition) { + if (const auto* tensor_sinfo = sinfo.as()) { + return f_condition(GetRef(tensor_sinfo)); + } else if (const auto* tuple_sinfo = sinfo.as()) { + return !std::any_of( + tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), + [&](const StructInfo& field) { return !IsNestedTensorConditioned(field, f_condition); }); + } + return false; +} + +/*! + * \brief Check if the given StructInfo is a nested tensor. + * \param sinfo The StructInfo to be checked. + * \return true if the given StructInfo is a nested tensor. + */ +bool IsNestedTensor(const StructInfo& sinfo); + +/*! + * \brief Check if the given expr is a nested tensor. + * \param expr The expr to be checked. 
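// Illustrative sketch (not taken from this patch): using the helper above to ask a
// dtype-specific question about a (possibly tuple-nested) StructInfo. `sinfo` would
// typically come from GetStructInfo(expr).
bool AllLeavesAreFloat16(const tvm::relax::StructInfo& sinfo) {
  using namespace tvm::relax;
  return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& t) {
    // Every leaf must be a float16 tensor; non-tensor leaves make the whole check false.
    return t->dtype == tvm::DataType::Float(16);
  });
}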
+ * \return true if the given expr is a nested tensor. + */ +bool IsNestedTensor(const Expr& expr); + +// TODO(@bohan): implements some postorder function accepts a visitor closure +class VarReplacer : public ExprMutator { + public: + using VarMap = std::unordered_map; + + explicit VarReplacer(const VarMap& var_remap) : var_remap_(var_remap) {} + + static Expr Replace(const Expr& expr, const VarMap& var_remap) { + VarReplacer replacer(var_remap); + return replacer(expr); + } + + private: + Expr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = var_remap_.find(var->vid); + return it == var_remap_.end() ? var : it->second; + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + Var var = GetRef(op); + auto it = var_remap_.find(var->vid); + return it == var_remap_.end() ? var : it->second; + } + + const VarMap& var_remap_; +}; + +/*! + * \brief Create a Constant with a scalar + * + * \param dtype The data type. + * \param value The value of the scalar. + * \return A Constant. + */ +template +inline Constant MakeConstantScalar(T value, DataType dtype) { + runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); + if (dtype == DataType::Float(32)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Float(64)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(32)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(64)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::UInt(1)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::UInt(8)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::UInt(16)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::UInt(32)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::UInt(64)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(8)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(16)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(32)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Int(64)) { + *static_cast(arr->data) = static_cast(value); + } else if (dtype == DataType::Float(16)) { + // convert to float16 storage is uint16_t + *static_cast(arr->data) = + __truncXfYf2__(static_cast(value)); + } else if (dtype == DataType::BFloat(16)) { + // convert to bfloat16 storage is uint16_t + *static_cast(arr->data) = + __truncXfYf2__(static_cast(value)); + } else { + LOG(FATAL) << "Unsupported dtype " << dtype; + } + return Constant(arr); +} + +inline Array GetOrderedPositiveAxes(const Array& axes, int ndim) { + std::vector ret; + ret.reserve(axes.size()); + for (const auto& axis : axes) { + int64_t axis_val = axis->value; + if (axis_val < 0) { + axis_val += ndim; + } + ICHECK(axis_val >= 0 && axis_val < ndim) << "axis " << axis << " is out of bounds for array of " + << "dimension " << ndim; + ret.push_back(axis_val); + } + std::sort(ret.begin(), ret.end()); + return support::AsArray(ret); +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_UTILS_H_ diff --git a/src/relax/utils.cc b/src/relax/utils.cc new file mode 100644 index 000000000000..cf1d9bed98c1 --- /dev/null +++ b/src/relax/utils.cc @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Helper to implement bind params.*/ +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map& args_map, + const tvm::Map& symbolic_var_map) + : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {} + + private: + Expr VisitExpr_(const FunctionNode* op) final { + tvm::Array params; + bool all_params_unchanged = true; + for (const Var& param : op->params) { + if (args_map_.count(param)) { + all_params_unchanged = false; + } else { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + if (!param.same_as(new_param)) { + this->var_remap_[param->vid] = new_param; + all_params_unchanged = false; + } + } + } + + Expr body = this->VisitWithNewScope(op->body, params); + + // FuncStructInfo does not depend on Expr + if (all_params_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->attrs); + } + } + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + + PrimExpr VisitPrimExpr(const PrimExpr& expr) final { + if (const tir::VarNode* var = expr.as()) { + auto it = symbolic_var_map_.find(GetRef(var)); + if (it != symbolic_var_map_.end()) { + return (*it).second; + } + } + return ExprMutator::VisitPrimExpr(expr); + } + + private: + const tvm::Map& args_map_; + const tvm::Map& symbolic_var_map_; +}; + +/*! 
+ * \brief Bind params on expr + * \param expr The expr where to bind params + * \param binds The map from param var to the expr it binds to + * \param symbolic_var_map The map from symbolic var to the expr it binds to + * \return The result expr after bind params + */ +Expr Bind(const Expr& expr, const tvm::Map& binds, + const tvm::Map& symbolic_var_map) { + return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); +} + +bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank, + bool permit_unknown_dtype) { + const TensorStructInfoNode* tt = sinfo.as(); + if (!tt) { + return false; + } + bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); + bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + return correct_dtype && correct_rank; +} + +bool IsLeafOrTuple(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} + +class FunctionCopier : public ExprMutator { + public: + static Function Transform(Function func) { + FunctionCopier copier; + // All variables that are bound inside the original function would be copied + // to satisfy the restriction in the well-formed check: Variables in Relax + // must be bound exactly once. + return Downcast(copier.VisitExpr(func)); + } + + Var VisitVarDef_(const DataflowVarNode* var) override { + Var new_var = ExprMutator::VisitVarDef_(var); + Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + var_remap_[var->vid] = copied_var; + return copied_var; + } + + Var VisitVarDef_(const VarNode* var) override { + Var new_var = ExprMutator::VisitVarDef_(var); + Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + var_remap_[var->vid] = copied_var; + return copied_var; + } +}; + +Function CopyWithNewVars(Function func) { return FunctionCopier::Transform(func); } + +TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index db8e0329d943..cdbfbed8db89 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -133,7 +133,7 @@ class CodegenCBase { * \brief Gerenate C code for the external function. * * \param func_name The name of the external function. - * \param args arguments to the external function. + * \param arg_types Types of arguments represented as string * * \code * @@ -160,14 +160,14 @@ class CodegenCBase { * * \endcode */ - void GenerateBackendCFunc(const std::string& func_name, const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const std::vector& arg_types, const std::string& const_arr_name, const std::vector& outs, bool pass_dl_tensor = false) { // Print signature code_stream_ << "\n"; code_stream_ << "int " << func_name << "_wrapper_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { code_stream_ << "DLTensor* arg" << i << ",\n"; code_stream_ << "\t"; } @@ -182,12 +182,11 @@ class CodegenCBase { // Generate the internal call. 
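// Illustrative sketch (not emitted verbatim by this patch): for a hypothetical
// external function "my_gemm" with two float32 inputs and one float32 output, the
// wrapper produced by this routine looks roughly like the following; the exact
// output-parameter naming and the surrounding TVMBackendPackedCFunc shim are elided.
//
//   int my_gemm_wrapper_(DLTensor* arg0,
//                        DLTensor* arg1,
//                        DLTensor* out0) {
//     my_gemm_((float*)(arg0->data),
//              (float*)(arg1->data),
//              (float*)(out0->data));
//     return 0;
//   }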
PrintIndents(); code_stream_ << func_name << "_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { if (pass_dl_tensor) { code_stream_ << "arg" << i << ",\n"; } else { - const auto& dtype_str = GetDtypeString(args[i]); - code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + code_stream_ << "(" << arg_types[i] << "*)(arg" << i << "->data),\n"; } PrintIndents(); } @@ -212,21 +211,21 @@ class CodegenCBase { // Create the external function PrintRuntimeFunctionHeader(func_name); EnterScope(); - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { PrintArgToData(i); } for (size_t i = 0; i < outs.size(); i++) { - PrintRetToData(args.size() + i); + PrintRetToData(arg_types.size() + i); } PrintIndents(); code_stream_ << func_name << "_wrapper_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { code_stream_ << "arg" << i << ","; } for (size_t i = 0; i < outs.size() - 1; i++) { - code_stream_ << "ret" << args.size() + i << ","; + code_stream_ << "ret" << arg_types.size() + i << ","; } - code_stream_ << "ret" << args.size() + outs.size() - 1 << ");\n"; + code_stream_ << "ret" << arg_types.size() + outs.size() - 1 << ");\n"; PrintIndents(); code_stream_ << "return 0;\n"; ExitScope(); @@ -256,6 +255,16 @@ class CodegenCBase { } } + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + std::vector arg_types; + for (size_t i = 0; i < args.size(); i++) { + arg_types.push_back(GetDtypeString(args[i])); + } + return GenerateBackendCFunc(func_name, arg_types, const_arr_name, outs, pass_dl_tensor); + } + /*! * \brief Emit the code for external runtime. * @@ -370,6 +379,10 @@ class CodegenCBase { dtype = "int"; } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { dtype = "int64_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 8)) { + dtype = "int8_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLUInt, 8)) { + dtype = "uint8_t"; } else { LOG(FATAL) << "Unsupported dtype " << ttype->dtype; } diff --git a/src/relay/backend/contrib/cutlass/codegen.h b/src/relay/backend/contrib/cutlass/codegen.h index e70e97a2fafa..03b8e6afbddc 100644 --- a/src/relay/backend/contrib/cutlass/codegen.h +++ b/src/relay/backend/contrib/cutlass/codegen.h @@ -27,6 +27,11 @@ #include +#include +#include + +#include "../codegen_c/codegen_c.h" + namespace tvm { namespace relay { namespace contrib { @@ -40,6 +45,22 @@ namespace cutlass { */ transform::Pass CompileForCutlass(); +// The rest is sparsely documented since they are exposed only for code sharing between Relay +// and Relax backend implementations. + +/*! \brief Emit the function signature for a kernel */ +std::string EmitSignature(const std::vector& out, + const std::string& func_id, const std::vector& arg_names); + +/*! \brief Generate the body of the kernel */ +GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& ext_func_id, + const std::vector& output_types, + const Array& func_args, const Map& attrs, + int* buf_idx); + +/*! 
\brief Create a C-source module from the given kernel string */ +runtime::Module Finalize(const std::string& code, const Array& func_names); + } // namespace cutlass } // namespace contrib } // namespace relay diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index f009bda9cd98..f7af74c4dbe0 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -444,6 +444,13 @@ TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern") return DefaultTIRConverterImpl(args, constants, true); }); +TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq") + .set_body_typed([](bool is_homogeneous, bool is_vm) { + auto pass_seqs = GetPassPrefix(is_homogeneous, is_vm); + transform::Sequential seq(pass_seqs); + return seq; + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index fc1f3a15077e..dd31a1f7367d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -154,8 +154,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index d2eb48073f7d..a152bbe9c3cb 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -950,7 +950,7 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E }); Pass InferType() { - auto pass_info = PassInfo(0, "InferType", {}); + auto pass_info = PassInfo(0, "InferType", {}, /* trace */ false); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { // Execute the pass function and return a new module. diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index eed41dfc2b99..e4687bb20d38 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -82,7 +82,10 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& int ret_type_code = kTVMNullptr; int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), args.num_args, &ret_value, &ret_type_code, nullptr); - ICHECK_EQ(ret, 0) << TVMGetLastError(); + // NOTE: important to keep the original error message. + if (ret != 0) { + LOG(FATAL) << TVMGetLastError(); + } if (ret_type_code != kTVMNullptr) { *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 1e81ac1bbb34..a89ef8b5efff 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -277,6 +277,15 @@ Module MetalModuleCreate(std::string data, std::string fmt, return Module(n); } +TVM_REGISTER_GLOBAL("runtime.module.create_metal_module") + .set_body_typed([](std::string data, std::string fmap_json) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); + return MetalModuleCreate(data, "metal", fmap, ""); + }); + // Load module from module. 
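+// Illustrative usage sketch (not part of this module): the packed function
+// registered above can be looked up through the global registry from C++, e.g.
+//
+//   const tvm::runtime::PackedFunc* f =
+//       tvm::runtime::Registry::Get("runtime.module.create_metal_module");
+//   ICHECK(f != nullptr);
+//   tvm::runtime::Module m = (*f)(metal_source, fmap_json);
+//
+// where `metal_source` and `fmap_json` are hypothetical placeholders for the
+// generated Metal source/library bytes and the JSON-serialized function map.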
Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 298fd588d5e1..5b7b3837d0df 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -184,6 +184,10 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int return mod->imports().at(index); }); +TVM_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { + mod->ClearImports(); +}); + TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc new file mode 100644 index 000000000000..5a7c1d662055 --- /dev/null +++ b/src/runtime/relax_vm/builtin.cc @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/builtin.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../runtime_base.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +using tvm::runtime::NDArray; + +//------------------------------------------------- +// Shape/StructInfo handling. +//------------------------------------------------- +/*! + * \brief Builtin function to allocate shape heap. + * \param ctx_ptr The context module pointer. + * \param size the size of the heap. + * \return An allocate NDArray as shape heap. + */ +NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { + VirtualMachine* vm = static_cast(ctx_ptr); + // use host allocator, which is always last element. + size_t host_device_index = vm->devices.size() - 1; + // specialy handle hexagon on-device RT. + // TODO(relax-team): visit and consider other possible choices. + if (vm->devices[0].device_type == kDLHexagon) { + host_device_index = 0; + } + auto* alloc = vm->allocators[host_device_index]; + return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); +} + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); + +/*! + * \brief Builtin match shape function. + * \param args The packed function arguments. + * \param rv The return value. + * + * \sa MatchShapeCode + */ +void MatchShape(TVMArgs args, TVMRetValue* rv) { + // input shape the first argument can take in tensor or shape. + ShapeTuple input_shape; + if (args[0].IsObjectRef()) { + input_shape = args[0].operator NDArray().Shape(); + } else { + input_shape = args[0]; + } + DLTensor* heap = args[1]; + int64_t* heap_data = heap == nullptr ? 
nullptr : static_cast(heap->data); + int64_t size = args[2]; + const int64_t kBeginCode = 3; + ICHECK_LE(kBeginCode + size * 2, args.size()); + // a function that lazily get context for error reporting + const int64_t kErrorContextOffset = kBeginCode + size * 2; + Optional err_ctx = args[kErrorContextOffset]; + + CHECK_EQ(input_shape.size(), size) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast shape size mismatch."; + + for (int64_t i = 0; i < size; ++i) { + MatchShapeCode code = static_cast(args[kBeginCode + i * 2].operator int()); + int64_t reg = args[kBeginCode + i * 2 + 1]; + + if (code == MatchShapeCode::kAssertEqualToImm) { + CHECK_EQ(input_shape[i], reg) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " shape[" << i << "]" + << " mismatch to specified constant."; + } else if (code == MatchShapeCode::kStoreToHeap) { + heap_data[reg] = input_shape[i]; + } else if (code == MatchShapeCode::kNoOp) { + } else { + ICHECK(code == MatchShapeCode::kAssertEqualToLoad); + CHECK_EQ(input_shape[i], heap_data[reg]) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " shape[" << i << "]" + << " mismatch to a previous populated value."; + } + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape); + +/*! + * \brief Builtin make shape function. + * \param args The packed function arguments. + * \param rv The return value. + * + * \sa MakeShapeCode + */ +void MakeShape(TVMArgs args, TVMRetValue* rv) { + // NOTE: heap can be nullptr + DLTensor* heap = args[0]; + int64_t* heap_data = heap == nullptr ? nullptr : static_cast(heap->data); + int64_t size = args[1]; + const int64_t kBeginCode = 2; + + std::vector shape(size); + + for (int64_t i = 0; i < size; ++i) { + MakeShapeCode code = static_cast(args[kBeginCode + i * 2].operator int()); + int64_t reg = args[kBeginCode + i * 2 + 1]; + if (code == MakeShapeCode::kUseImm) { + shape[i] = reg; + } else { + ICHECK(code == MakeShapeCode::kLoadShape); + shape[i] = heap_data[reg]; + } + } + *rv = ShapeTuple(std::move(shape)); +} + +TVM_REGISTER_GLOBAL("vm.builtin.make_shape").set_body(MakeShape); + +/*! + * \brief Builtin function to check if arg is Tensor(dtype, ndim) + * \param arg The input argument. + * \param ndim Expected ndim of the Tensor, can be -1 (indicate unknown). + * \param dtype The expected content data type. + * \param err_ctx Additional context if error occurs. + */ +void CheckTensorInfo(TVMArgs args, TVMRetValue* rv) { + ObjectRef arg = args[0]; + int ndim = args[1]; + DataType dtype; + Optional err_ctx; + + if (args.size() == 3) { + dtype = DataType::Void(); + err_ctx = args[2].operator Optional(); + } else { + dtype = args[2]; + err_ctx = args[3].operator Optional(); + } + + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tensor but get " + << arg->GetTypeKey(); + + if (ndim != -1) { + CHECK(ptr->dl_tensor.ndim == ndim) + << "ValueError: " << err_ctx.value_or("") << " expect Tensor with ndim " << ndim + << " but get " << ptr->dl_tensor.ndim; + } + + if (dtype != DataType::Void()) { + CHECK(DataType(ptr->dl_tensor.dtype) == dtype) + << "ValueError: " << err_ctx.value_or("") << " expect Tensor with dtype " << dtype + << " but get " << ptr->dl_tensor.dtype; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body(CheckTensorInfo); + +/*! + * \brief Builtin function to check if arg is Shape(ndim) + * \param arg The input argument. 
+ * \param ndim Expected size of the shape, can be -1 (indicate unknown). + * \param err_ctx Additional context if error occurs. + */ +void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { + // a function that lazily get context for error reporting + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " + << arg->GetTypeKey(); + if (ndim != -1) { + CHECK(ptr->size == static_cast(ndim)) + << "ValueError: " << err_ctx.value_or("") << " expect Shape with ndim " << ndim + << " but get " << ptr->size; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); + +/*! + * \brief Builtin function to check if arg is Tuple with size elements. + * \param arg The input argument. + * \param size The expected size of the tuple. + * \param err_ctx Additional context if error occurs. + */ +void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { + // a function that lazily get context for error reporting + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " + << arg->GetTypeKey(); + CHECK(static_cast(ptr->size()) == size) + << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " + << " but get a Tuple with " << ptr->size() << " elements."; +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); + +/*! + * \brief Builtin function to check if arg is a callable function. + * \param arg The input argument. + * \param err_ctx Additional context if error occurs. + */ +void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { + // a function that lazily get context for error reporting + bool is_func = arg.as() || arg.as(); + CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " + << arg->GetTypeKey(); +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); + +//------------------------------------------------- +// Storage management. +//------------------------------------------------- +Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_size, Index device_index, + DLDataType dtype_hint) { + VirtualMachine* vm = static_cast(ctx_ptr); + + ICHECK_EQ(buffer_size.size(), 1); + int alignment = runtime::kAllocAlignment; + ICHECK_LT(device_index, vm->devices.size()) + << "The device index is out of VM physical devices list"; + + if (device_index == -1) { + // Allocate on host. Host is always the last element of vm->devices. 
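+    // Illustrative example (assumed device setup): if the VM holds
+    // devices = {cuda(0), cpu(0)}, then device_index == 0 selects the CUDA
+    // allocator, while device_index == -1 falls through to index 1, i.e. the
+    // CPU allocator serving as the host allocator.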
+ device_index = vm->devices.size() - 1; + } + + int64_t size_imm = buffer_size[0]; + + auto storage_obj = runtime::SimpleObjAllocator().make_object(); + auto* alloc = vm->allocators[device_index]; + ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + storage_obj->buffer = alloc->Alloc(size_imm, alignment, dtype_hint); + Storage storage(storage_obj); + return storage; +} + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); + +//------------------------------------------------- +// Closure function handling, calling convention +//------------------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.make_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + VMClosure clo = args[0]; + std::vector saved_args; + saved_args.resize(args.size() - 1); + for (size_t i = 0; i < saved_args.size(); ++i) { + saved_args[i] = args[i + 1]; + } + auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); + *rv = VMClosure(clo->func_name, impl); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.invoke_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef vm_closure = args[1]; + vm->InvokeClosurePacked(vm_closure, + TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2), rv); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.call_tir_dyn").set_body([](TVMArgs args, TVMRetValue* rv) { + PackedFunc func = args[0]; + ShapeTuple to_unpack = args[args.size() - 1]; + size_t num_tensor_args = args.size() - 2; + + std::vector values(num_tensor_args + to_unpack.size()); + std::vector tcodes(num_tensor_args + to_unpack.size()); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + + std::copy(args.values + 1, args.values + args.size() - 1, values.data()); + std::copy(args.type_codes + 1, args.type_codes + args.size() - 1, tcodes.data()); + + for (size_t i = 0; i < to_unpack.size(); ++i) { + setter(i + num_tensor_args, to_unpack[i]); + } + TVMArgs func_args(values.data(), tcodes.data(), values.size()); + func.CallPacked(func_args, rv); +}); + +//------------------------------------- +// Builtin runtime operators. +//------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); + +TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = args[0]; +}); + +TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeTuple new_shape) { + return data.CreateView(new_shape, data->dtype); +}); + +/*! + * \brief Load the scalar value in cond and return the result value. 
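+ * \note The condition may be passed either as a plain integer immediate or as
+ *       a scalar integer NDArray; an NDArray that lives on a non-CPU device is
+ *       first copied to the CPU before its value is read.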
+ * \param cond The condition + * \return Bool + */ +bool ReadIfCond(TVMArgValue cond) { + if (cond.type_code() == kDLInt) return cond.operator bool(); + NDArray arr = cond.operator tvm::runtime::NDArray(); + if (arr->device.device_type != kDLCPU) { + arr = arr.CopyTo(DLDevice{kDLCPU, 0}); + } + ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt); + int64_t result; + switch (arr->dtype.bits) { + case 1: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 8: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 16: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 32: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 64: { + result = reinterpret_cast(arr->data)[0]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype); + throw; + } + return result != 0; +} + +TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); + +//------------------------------------- +// Data structure API +//------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem") + .set_body_typed([](runtime::Array arr, int64_t index) { return arr[index]; }); + +TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Array arr; + for (int i = 0; i < args.num_args; ++i) { + arr.push_back(args[i].operator ObjectRef()); + } + *rv = arr; +}); + +TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { + NDArray arr = data; + if (data->device.device_type != kDLCPU) { + arr = data.CopyTo(DLDevice{kDLCPU, 0}); + } + + ICHECK_EQ(arr->ndim, 1); + ICHECK_EQ(arr->dtype.code, kDLInt); + + std::vector out_shape; + for (int i = 0; i < arr.Shape()[0]; ++i) { + int64_t result; + switch (arr->dtype.bits) { + case 16: { + result = reinterpret_cast(arr->data)[i]; + break; + } + case 32: { + result = reinterpret_cast(arr->data)[i]; + break; + } + case 64: { + result = reinterpret_cast(arr->data)[i]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype); + throw; + } + out_shape.push_back(result); + } + return ShapeTuple(out_shape); +}); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +//------------------------------------------------- +// AnyList C runtime API: keep in relax for now. +//-------------------------------------------------- +extern "C" { +/*! + * \brief Backend function to get anylist item and set into Packed Func call arg stack. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + * \param args The args stack. + * \param type_codes The type codes stack. + * \param arg_offset The offset of argument. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, + int arg_offset); +/*! + * \brief Backend function to get anylist item and set into Packed Func call arg stack. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + */ +TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); + +/*! + * \brief Backend function to set anylist item by moving from packed func return. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + * \param args The args stack. + * \param type_codes The type codes stack. + * \param arg_offset The offset of argument. 
+ * \return 0 when no error is thrown, -1 when failure happens. + */ +TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, + int* type_codes, int ret_offset); + +int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, + int arg_offset) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + TVMArgsSetter setter(args, type_codes); + setter(arg_offset, list[index]); + API_END(); +} + +int TVMBackendAnyListResetItem(void* anylist, int index) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + list[index] = nullptr; + API_END(); +} + +int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, int* type_codes, + int ret_offset) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + if (type_codes[ret_offset] == kTVMStr || type_codes[ret_offset] == kTVMBytes) { + list[index] = TVMArgValue(args[ret_offset], type_codes[ret_offset]); + } else { + list[index] = TVMRetValue::MoveFromCHost(args[ret_offset], type_codes[ret_offset]); + } + API_END(); +} +} // extern "C" diff --git a/src/runtime/relax_vm/bytecode.cc b/src/runtime/relax_vm/bytecode.cc new file mode 100644 index 000000000000..9084207848b5 --- /dev/null +++ b/src/runtime/relax_vm/bytecode.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/bytecode.cc + * \brief The bytecode for Relax virtual machine. 
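+ * \note The Relax VM currently uses four opcodes: Call (invoke an entry of the
+ *       executable's function table), Ret (return a register), Goto (relative
+ *       jump), and If (conditional branch with a false-branch offset). Each
+ *       Instruction is a plain value object; the Executable stores instructions
+ *       flattened as ExecWord sequences and rebuilds them through
+ *       Executable::GetInstruction.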
+ */ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg* args, RegName dst) { + Instruction instr; + instr.op = Opcode::Call; + instr.dst = dst; + instr.func_idx = func_idx; + instr.num_args = num_args; + instr.args = args; + return instr; +} + +Instruction Instruction::Ret(RegName result) { + Instruction instr; + instr.op = Opcode::Ret; + instr.result = result; + return instr; +} + +Instruction Instruction::Goto(Index pc_offset) { + Instruction instr; + instr.op = Opcode::Goto; + instr.pc_offset = pc_offset; + return instr; +} + +Instruction Instruction::If(RegName cond, Index false_offset) { + Instruction instr; + instr.op = Opcode::If; + instr.cond = cond; + instr.false_offset = false_offset; + return instr; +} +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc new file mode 100644 index 000000000000..45342cf4ffa2 --- /dev/null +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/cuda_graph_builtin.cc + * \brief The CUDA graph related builtin functions for Relax virtual machine. + */ + +#include +#include +#include + +#include "../../cuda/cuda_common.h" +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief Represents a CUDA graph. */ +class CUDAGraphNode : public Object { + public: + cudaGraph_t handle_ = nullptr; + + ~CUDAGraphNode() { + if (handle_ != nullptr) { + cudaGraphDestroy(handle_); + } + } + + TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphNode, Object); +}; + +/*! + * \brief Managed reference to CUDAGraphNode + * \sa CUDAGraphNode + */ +class CUDAGraph : public ObjectRef { + public: + explicit CUDAGraph(cudaGraph_t handle) { + auto n = make_object(); + n->handle_ = handle; + data_ = std::move(n); + } + TVM_DEFINE_OBJECT_REF_METHODS(CUDAGraph, ObjectRef, CUDAGraphNode); +}; + +/*! \brief The cache states of a CUDA graph. */ +class CUDAGraphCache : public Object { + public: + struct CaptureResult { + /*! + * \brief Tuple of intemediate tensors in the capture func that will be used outside the + * capture func + */ + ObjectRef states; + /*! \brief The cuda graph instance */ + CUDAGraph graph; + }; + + static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore::Get(); } + + /*! + * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode. + * \param vm The virtual machine. + * \param capture_func The function of type (args...) 
-> Tuple[ObjectRef], where 'args' are the + * static arguments that are the same for all invocations of the capture function, the returned + * tuple contains the intermediate tensors that will be used outside the capture function. + * \param args The static arguments of the capture function + * \param entry_index The unique index of the capture function used for lookup. + * \return The return value of the capture function. + */ + ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args, + int64_t entry_index) { + if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) { + LOG(INFO) << "HIT"; + // Launch CUDA graph + const auto& [states, cuda_graph] = it->second; + cudaGraphExec_t cuda_graph_exec; + CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, cuda_graph->handle_, NULL, NULL, 0)); + CUDA_CALL(cudaGraphLaunch(cuda_graph_exec, CUDAThreadEntry::ThreadLocal()->stream)); + CUDA_CALL(cudaGraphExecDestroy(cuda_graph_exec)); + return states; + } + + cudaStream_t capture_stream; + CUDA_CALL(cudaStreamCreate(&capture_stream)); + CUDAGraphCache::CaptureResult entry; + + // Set up arguments for the graph execution + Array tuple_args = Downcast>(args); + int nargs = static_cast(tuple_args.size()); + std::vector values(nargs); + std::vector tcodes(nargs); + TVMArgsSetter setter(values.data(), tcodes.data()); + for (int i = 0; i < nargs; ++i) { + ObjectRef arg = tuple_args[i]; + setter(i, arg); + } + + TVMRetValue capture_func_rv; + // Run the function without CUDA graph. This is a warm up step to do necessary initialization + // of the CUDA module such as loading module data, setting kernel attributes. + vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs), + &capture_func_rv); + + // Run the graph in capture mode + cudaGraph_t graph; + std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream); + CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream, + cudaStreamCaptureModeGlobal)); + + vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs), + &capture_func_rv); + entry.states = capture_func_rv; + CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph)); + std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream); + + entry.graph = CUDAGraph(graph); + capture_cache_[entry_index] = entry; + CUDA_CALL(cudaStreamDestroy(capture_stream)); + return entry.states; + } + + /*! + * \brief Get the cached allocation from the cache or run the allocation function. + * \param vm The virtual machine. + * \param alloc_func The function of type () -> ObjectRef, where the returned object is the + * tuple of allocated storage objects. + * \param entry_index The unique index of the allocation function used for lookup. + */ + ObjectRef GetCachedAllocation(VirtualMachine* vm, const ObjectRef& alloc_func, + int64_t entry_index) { + if (auto it = alloc_cache_.find(entry_index); it != alloc_cache_.end()) { + return it->second; + } + TVMRetValue alloc_func_rv; + vm->InvokeClosurePacked(alloc_func, TVMArgs(nullptr, nullptr, 0), &alloc_func_rv); + ObjectRef alloc_result = alloc_func_rv; + alloc_cache_[entry_index] = alloc_result; + return alloc_result; + } + + private: + /*! + * \brief The cache of captured cuda graphs. The key is a unique index for the capture function. + * The value is the result of the capture. + */ + std::unordered_map capture_cache_; + /*! + * \brief The cache of allocations. The key is a unique index for the allocation function. 
+ * The value is the cached allocations, which is a tuple of storages. + */ + std::unordered_map alloc_cache_; +}; + +TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") + .set_body_typed([](TVMArgValue vm_ptr, ObjectRef capture_func, ObjectRef func_args, + int64_t entry_index) { + VirtualMachine* vm = VirtualMachine::GetContextPtr(vm_ptr); + CUDAGraphCache* cache = CUDAGraphCache::Get(); + return cache->RunOrCapture(vm, capture_func, func_args, entry_index); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") + .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size(), 3); + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef alloc_func = args[1]; + int64_t entry_index = args[2]; + CUDAGraphCache* cache = CUDAGraphCache::Get(); + *rv = cache->GetCachedAllocation(vm, alloc_func, entry_index); + }); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc new file mode 100644 index 000000000000..2090a3b25413 --- /dev/null +++ b/src/runtime/relax_vm/executable.cc @@ -0,0 +1,583 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/executable.cc + */ + +#include +#include +#include +#include + +#include +#include + +#include "../file_utils.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief The magic number for the serialized VM bytecode file */ +constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; + +/*! \brief Possible types in the constant pool */ +enum ConstantType : int { + kNDArray = 0, + kDLDataType = 1, + kShapeTuple = 2, + kString = 3, + kInt = 4, +}; + +#define STREAM_CHECK(val, section) \ + ICHECK(val) << "Invalid VM file format in the " << section << " section." 
\ + << "\n"; + +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); + } else if (name == "as_text") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); }); + } else if (name == "as_python") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); }); + } else if (name == "vm_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = VirtualMachine::Create(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); + } else if (name == "vm_profiler_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = VirtualMachine::CreateProfiler(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); + } + return nullptr; +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relax VM executable statistics:" << std::endl; + + // Get the number of constants. + // If the constant is an NDArray, get the shape of each of them. + // If the constant is an DLDataType, get the data type of each of them. + oss << " Constant pool (# " << constants.size() << "): ["; + for (const auto& it : constants) { + if (it.IsObjectRef()) { + const auto ndarray = it.operator tvm::runtime::NDArray(); + const auto& shape = ndarray.Shape(); + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + oss << "shapetuple["; + for (size_t i = 0; i < shape.size(); ++i) { + oss << shape.at(i) << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + std::string f = it.AsObjectRef().operator std::string(); + oss << "\""; + oss << f; + oss << "\", "; + } else if (it.type_code() == kDLInt) { + oss << static_cast(it); + oss << ", "; + } else { + try { + DLDataType dtype = it.operator DLDataType(); + oss << dtype; + oss << ", "; + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. 
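+  // Illustrative output line (hypothetical function names):
+  //   Globals (#2): [main, fused_add]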
+ oss << " Globals (#" << func_table.size() << "): ["; + for (const auto& it : func_table) { + oss << it.name << ", "; + } + if (!func_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +void Executable::SetInstructionData(Index i, Index j, ExecWord val) { + ICHECK_LT(i, instr_offset.size()); + Index instr_idx = instr_offset[i]; + ICHECK_LT(instr_idx + j, instr_data.size()); + instr_data[instr_idx + j] = val; +} + +Instruction Executable::GetInstruction(Index i) const { + Index offset = instr_offset[i]; + Opcode op = static_cast(instr_data[offset]); + switch (op) { + case Opcode::Call: { + RegName dst = instr_data[offset + 1]; + Index func_idx = instr_data[offset + 2]; + Index num_args = instr_data[offset + 3]; + ExecWord* args = const_cast(&instr_data[offset + 4]); + return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); + } + case Opcode::Ret: { + RegName result = instr_data[offset + 1]; + return Instruction::Ret(result); + } + case Opcode::Goto: { + Index pc_offset = instr_data[offset + 1]; + return Instruction::Goto(pc_offset); + } + case Opcode::If: { + RegName cond = instr_data[offset + 1]; + Index false_offset = instr_data[offset + 2]; + return Instruction::If(cond, false_offset); + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(op); + break; + } + return Instruction(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +void Executable::SaveToBinary(dmlc::Stream* stream) { + std::string code; + // Initialize the stream object. + dmlc::MemoryStringStream strm(&code); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Code section. + SaveCodeSection(&strm); + + stream->Write(code); +} + +void Executable::SaveToFile(const std::string& file_name, const std::string& format) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + Executable::SaveToBinary(strm); + runtime::SaveBinaryToFile(file_name, data); +} + +Module Executable::LoadFromBinary(void* stream) { + std::string code; + static_cast(stream)->Read(&code); + dmlc::MemoryStringStream strm(&code); + + ObjectPtr exec = make_object(); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Code section. 
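+  // For reference, the layout read here mirrors what SaveToBinary wrote
+  // (informal sketch):
+  //   [uint64 kTVMVMBytecodeMagic][string TVM_VERSION]
+  //   [global section:   func_table]
+  //   [constant section: count + tagged entries]
+  //   [code section:     instr_offset + instr_data]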
+ exec->LoadCodeSection(&strm); + + return Module(exec); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable") + .set_body_typed(Executable::LoadFromBinary); + +Module Executable::LoadFromFile(const std::string& file_name) { + std::string data; + runtime::LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + return Executable::LoadFromBinary(reinterpret_cast(strm)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable") + .set_body_typed(Executable::LoadFromFile); + +void VMFuncInfo::Save(dmlc::Stream* strm) const { + int32_t temp_kind = static_cast(kind); + strm->Write(temp_kind); + strm->Write(name); + strm->Write(start_instr); + strm->Write(end_instr); + strm->Write(num_args); + strm->Write(register_file_size); + strm->Write(param_names); +} + +bool VMFuncInfo::Load(dmlc::Stream* strm) { + int32_t temp_kind; + if (!strm->Read(&temp_kind)) return false; + this->kind = static_cast(temp_kind); + if (!strm->Read(&name)) return false; + if (!strm->Read(&start_instr)) return false; + if (!strm->Read(&end_instr)) return false; + if (!strm->Read(&num_args)) return false; + if (!strm->Read(®ister_file_size)) return false; + if (!strm->Read(¶m_names)) return false; + return true; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + strm->Write(static_cast(this->constants.size())); + for (const auto& it : this->constants) { + if (it.IsObjectRef()) { + strm->Write(ConstantType::kNDArray); + runtime::SaveDLTensor(strm, it.operator DLTensor*()); + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + strm->Write(ConstantType::kShapeTuple); + strm->Write(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + strm->Write(shape.at(i)); + } + } else if (it.IsObjectRef()) { + String str = it.operator String(); + strm->Write(ConstantType::kString); + strm->Write(str.size()); + for (size_t i = 0; i < str.size(); ++i) { + strm->Write(str.at(i)); + } + } else if (it.type_code() == kDLInt) { + strm->Write(ConstantType::kInt); + strm->Write(it.value()); + } else { + try { + strm->Write(ConstantType::kDLDataType); + strm->Write(it.operator DLDataType()); + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray, DLDataType, and Integers but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } +} + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + strm->Write(instr_offset); + strm->Write(instr_data); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&func_table), "Global Section"); + // setup func map + for (size_t i = 0; i < func_table.size(); ++i) { + this->func_map[func_table[i].name] = i; + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + runtime::NDArray ndarray; + DLDataType dtype; + // Load each of the constants. 
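+  // Each serialized entry is a ConstantType tag followed by its payload
+  // (informal sketch):
+  //   kNDArray    -> DLTensor saved via SaveDLTensor
+  //   kShapeTuple -> uint64 length + int64 elements
+  //   kString     -> uint64 length + characters
+  //   kInt        -> int64 value
+  //   kDLDataType -> raw DLDataType struct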
+ for (size_t i = 0; i < size; i++) { + int constant_type; + STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); + if (constant_type == ConstantType::kNDArray) { + ndarray.Load(strm); + TVMRetValue cell; + cell = ndarray; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kShapeTuple) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = ShapeTuple(data); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kDLDataType) { + strm->Read(&dtype); + TVMRetValue cell; + cell = dtype; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kString) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = String(std::string(data.begin(), data.end())); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kInt) { + int64_t value; + strm->Read(&value); + TVMRetValue cell; + cell = value; + this->constants.push_back(cell); + } else { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool."; + } + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); + STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); +} + +template +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ", + std::function repr = std::to_string) { + if (cnt == 0) { + return ""; + } + std::ostringstream oss; + oss << repr(items[offset]); + for (int i = 1; i < cnt; ++i) { + oss << delim << repr(items[offset + i]); + } + return oss.str(); +} + +std::string RegNameToStr(RegName reg) { + if (reg == Instruction::kVoidRegister) { + return "%void"; + } + if (reg == Instruction::kVMRegister) { + return "%vm"; + } + return "%" + std::to_string(reg); +} + +String Executable::AsText() const { + auto get_func_name = [&](Index index) -> std::string { + if (static_cast(index) < func_table.size()) { + return func_table[index].name; + } else { + return "unknown_func_index(" + std::to_string(index) + ")"; + } + }; + + auto instr_to_str = [&](Instruction::Arg arg) -> std::string { + // only for argument + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: + return RegNameToStr(arg.value()); + case Instruction::ArgKind::kImmediate: + return "i" + std::to_string(arg.value()); + case Instruction::ArgKind::kConstIdx: + return "c[" + std::to_string(arg.value()) + "]"; + case Instruction::ArgKind::kFuncIdx: + return "f[" + get_func_name(arg.value()) + "]"; + default: + LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + return ""; + } + }; + + // print the text format + std::ostringstream os; + for (size_t fidx = 0; fidx < this->func_table.size(); ++fidx) { + const VMFuncInfo& gfunc = this->func_table[fidx]; + if (gfunc.kind == VMFuncInfo::FuncKind::kPackedFunc) { + os << "@" << gfunc.name << " packed_func;\n\n"; + continue; + } + if (gfunc.kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + os << "@" << gfunc.name << " num_inputs=" << gfunc.num_args << " vm_tir_func;\n\n"; + continue; + } + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + os << "@" << gfunc.name << ":\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = gfunc.end_instr; + + for (size_t idx = start_instr; idx 
< end_instr; ++idx) { + os << " "; + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << std::setw(6) << std::left << "call" << std::setw(16) << std::left + << get_func_name(instr.func_idx) << " in: " << std::setw(12) << std::left + << StrJoin(instr.args, 0, instr.num_args, ", ", instr_to_str) + << " dst: " << RegNameToStr(instr.dst) << "\n"; + break; + } + case Opcode::Ret: { + os << std::setw(6) << std::left << "ret " << RegNameToStr(instr.result) << "\n"; + break; + } + case Opcode::Goto: { + os << std::setw(6) << std::left << "goto" << instr.pc_offset << "\n"; + break; + } + case Opcode::If: { + os << std::setw(6) << std::left << "If" << RegNameToStr(instr.cond) << ", " + << instr.false_offset << "\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + os << "\n"; + } + return String(os.str()); +} + +String Executable::AsPython() const { + auto get_func_name = [&](Index index) -> std::string { + if (static_cast(index) < func_table.size()) { + return "\"" + func_table[index].name + "\""; + } else { + return "ib.unknown_func_index(" + std::to_string(index) + ")"; + } + }; + + auto arg_to_py_str = [&](Instruction::Arg arg) -> std::string { + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: + if (arg.value() == Instruction::kVMRegister) { + return "ib.r(vm)"; + } + return "ib.r(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kImmediate: + return "ib.imm(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kConstIdx: + return "ib.c(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kFuncIdx: { + return "ib.f(" + get_func_name(arg.value()) + ")"; + } + default: + LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + return ""; + } + }; + + // print the python format + std::ostringstream os; + os << "ib = rx.Builder()\n"; + for (size_t fidx = 0; fidx < this->func_table.size(); ++fidx) { + const VMFuncInfo& gfunc = this->func_table[fidx]; + if (gfunc.kind == VMFuncInfo::FuncKind::kPackedFunc) { + continue; + } + if (gfunc.kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + continue; + } + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + + os << "with ib.function(\"" << gfunc.name << "\", num_inputs=" << gfunc.num_args << "):\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = gfunc.end_instr; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << " ib.emit_call(" << get_func_name(instr.func_idx) << ", args=[" + << StrJoin(instr.args, 0, instr.num_args, ", ", arg_to_py_str) + << "]"; + if (instr.dst != Instruction::kVoidRegister) os << ", dst=ib.r(" << instr.dst << ")"; + os << ")\n"; + break; + } + case Opcode::Ret: { + os << " ib.emit_ret(ib.r(" << instr.result << "))\n"; + break; + } + case Opcode::Goto: { + os << " ib.emit_goto(" << instr.pc_offset << ")\n"; + break; + } + case Opcode::If: { + os << " ib.emit_if(ib.r(" << instr.cond << "), " << instr.false_offset << ")\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } + return String(os.str()); +} + +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc new 
file mode 100644 index 000000000000..339045f515cf --- /dev/null +++ b/src/runtime/relax_vm/memory_manager.cc @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/memory_manager.cc + * \brief Allocate and manage memory for the Relay VM. + */ +#include + +#include +#include + +#include "naive_allocator.h" +#include "pooled_allocator.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); + ICHECK(ptr->manager_ctx != nullptr); + Buffer* buffer = reinterpret_cast(ptr->manager_ctx); + MemoryManager::GetAllocator(buffer->device)->Free(*(buffer)); + delete buffer; + delete ptr; +} + +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); + // When invoking AllocNDArray we don't own the underlying allocation + // and should not delete the buffer, but instead let it be reclaimed + // by the storage object's destructor. + // + // We did bump the reference count by 1 to keep alive the StorageObj + // allocation in case this NDArray is the sole owner. + // + // We decrement the object allowing for the buffer to release our + // reference count from allocation. + StorageObj* storage = reinterpret_cast(ptr->manager_ctx); + storage->DecRef(); + delete ptr; +} + +inline void VerifyDataType(DLDataType dtype) { + ICHECK_GE(dtype.lanes, 1); + if (dtype.code == kDLFloat) { + ICHECK_EQ(dtype.bits % 8, 0); + } else { + // allow uint1 as a special flag for bool. + if (dtype.bits == 1 && dtype.code == kDLUInt) return; + ICHECK_EQ(dtype.bits % 8, 0); + } + ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); +} + +inline size_t GetDataAlignment(const DLTensor& arr) { + size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; + if (align < runtime::kAllocAlignment) return runtime::kAllocAlignment; + return align; +} + +runtime::NDArray StorageObj::AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype) { + VerifyDataType(dtype); + + // critical zone: allocate header, cannot throw + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device); + + container->SetDeleter(StorageObj::Deleter); + size_t needed_size = runtime::GetDataSize(container->dl_tensor); + this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. + container->manager_ctx = reinterpret_cast(this); + + // is this UB? 
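+  // (In practice, no: the arithmetic below only advances a byte pointer within
+  // the single allocation owned by this storage object and casts it back to
+  // void*, and the offset + needed_size bound is checked once the NDArray has
+  // been constructed.)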
+ // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + runtime::NDArray ret(runtime::GetObjectPtr(container)); + // RAII in effect, now run the check. + + ICHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; + + return ret; +} + +MemoryManager* MemoryManager::Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static auto* inst = new MemoryManager(); + return inst; +} + +Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + if (m->allocators_.find(dev) == m->allocators_.end()) { + std::unique_ptr alloc; + switch (type) { + case kNaive: { + DLOG(INFO) << "New naive allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new NaiveAllocator(dev)); + break; + } + case kPooled: { + DLOG(INFO) << "New pooled allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new PooledAllocator(dev)); + break; + } + default: + LOG(FATAL) << "Unknown allocator type: " << type; + } + auto ret = alloc.get(); + m->allocators_.emplace(dev, std::move(alloc)); + return ret; + } + auto alloc = m->allocators_.at(dev).get(); + if (alloc->type() != type) { + LOG(WARNING) << "The type of existing allocator for " << runtime::DeviceName(dev.device_type) + << "(" << dev.device_id << ") is different from the request type (" + << alloc->type() << " vs " << type << ")"; + } + return alloc; +} + +Allocator* MemoryManager::GetAllocator(Device dev) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + auto it = m->allocators_.find(dev); + if (it == m->allocators_.end()) { + LOG(FATAL) << "Allocator for " << runtime::DeviceName(dev.device_type) << "(" << dev.device_id + << ") has not been created yet."; + } + return it->second.get(); +} + +runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev) { + VerifyDataType(dtype); + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, dev); + container->SetDeleter(BufferDeleter); + size_t size = runtime::GetDataSize(container->dl_tensor); + size_t alignment = GetDataAlignment(container->dl_tensor); + Buffer* buffer = new Buffer; + *buffer = this->Alloc(size, alignment, dtype); + container->manager_ctx = reinterpret_cast(buffer); + container->dl_tensor.data = buffer->data; + return runtime::NDArray(runtime::GetObjectPtr(container)); +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/naive_allocator.h b/src/runtime/relax_vm/naive_allocator.h new file mode 100644 index 000000000000..843a559602ab --- /dev/null +++ b/src/runtime/relax_vm/naive_allocator.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/naive_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class NaiveAllocator final : public Allocator { + public: + explicit NaiveAllocator(Device dev) : Allocator(kNaive), used_memory_(0), device_(dev) {} + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + Buffer buf; + buf.device = device_; + buf.size = nbytes; + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint); + used_memory_.fetch_add(nbytes, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + runtime::DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data); + used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed); + DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; + } + + private: + std::atomic used_memory_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/pooled_allocator.h b/src/runtime/relax_vm/pooled_allocator.h new file mode 100644 index 000000000000..0dd7d8b0277b --- /dev/null +++ b/src/runtime/relax_vm/pooled_allocator.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/runtime/relax_vm/pooled_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class PooledAllocator final : public Allocator { + public: + static constexpr size_t kDefaultPageSize = 4096; + + explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize) + : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {} + + ~PooledAllocator() { ReleaseAll(); } + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.device = device_; + buf.size = size; + try { + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + + used_memory_.fetch_add(size, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + DLOG(INFO) << "reclaim buffer " << buffer.size; + } + + private: + void ReleaseAll() { + std::lock_guard lock(mu_); + for (auto const& it : memory_pool_) { + auto const& pool = it.second; + for (auto const& buf : pool) { + runtime::DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); + } + } + memory_pool_.clear(); + used_memory_ = 0; + DLOG(INFO) << "release all buffers"; + } + + private: + size_t page_size_; + std::atomic used_memory_; + std::unordered_map > memory_pool_; + std::recursive_mutex mu_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc new file mode 100644 index 000000000000..2e6c3412138c --- /dev/null +++ b/src/runtime/relax_vm/vm.cc @@ -0,0 +1,1038 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/runtime/relax_vm/vm.cc + */ + +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//--------------------------------------------- +// VM Closure object +//--------------------------------------------- +TVM_REGISTER_OBJECT_TYPE(VMClosureObj); + +VMClosure::VMClosure(String func_name, PackedFunc impl) { + auto ptr = make_object(); + ptr->func_name = func_name; + ptr->impl = std::move(impl); + data_ = std::move(ptr); +} + +/*! + * \brief Create another PackedFunc with last arguments already bound to last_args. + * \param func The input func, can be a VMClosure or PackedFunc. + * \param last_args The arguments to bound to in the end of the function. + * \note The new function takes in arguments and append the last_args in the end. + */ +PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector last_args) { + return PackedFunc([func, last_args](TVMArgs args, TVMRetValue* rv) { + std::vector values(args.size() + last_args.size()); + std::vector tcodes(args.size() + last_args.size()); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + std::copy(args.values, args.values + args.size(), values.data()); + std::copy(args.type_codes, args.type_codes + args.size(), tcodes.data()); + for (size_t i = 0; i < last_args.size(); ++i) { + setter(i + args.size(), last_args[i]); + } + func.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), rv); + }); +} + +//----------------------------------------------------------- +// Utility functions. +//----------------------------------------------------------- +// Use the args after `starting_arg_idx` as a series of indices into `obj`, +// indexing into nested ADTs and returning the final indexed object. +ObjectRef IndexIntoNestedObject(ObjectRef obj, TVMArgs args, int starting_arg_idx) { + for (int i = starting_arg_idx; i < args.size(); i++) { + // the object must be an ADT to be able to index into it + if (!obj.as()) { + LOG(FATAL) << "ValueError: Attempted to index into an object that is not an ADT."; + } + int index = args[i]; + auto arr = Downcast>(obj); + // make sure the index is in bounds + if (index >= static_cast(arr.size())) { + LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; + } + obj = arr[index]; + } + return obj; +} + +NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* alloc) { + if (src->device.device_type == dev.device_type && src->device.device_id == dev.device_id) { + return src; + } else { + auto res = alloc->Empty(src.Shape(), src->dtype, dev); + res.CopyFrom(src); + return res; + } +} + +ObjectRef ConvertObjectToDevice(ObjectRef src, const Device& dev, Allocator* alloc) { + if (src->IsInstance()) { + return ConvertNDArrayToDevice(Downcast(src), dev, alloc); + } else if (src->IsInstance()) { + std::vector ret; + auto arr = Downcast>(src); + for (size_t i = 0; i < arr.size(); i++) { + ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); + } + return Array(ret.begin(), ret.end()); + } else { + return src; + } +} + +TVMRetValue ConvertArgToDevice(TVMArgValue input, Device dev, Allocator* alloc) { + // NOTE: NDArray::FromExternalDLTensor is not safe + // in terms of memory-behavior. + // To be extra careful, we copy DLTensor. + // The developer can still explicitly allocate NDArray + // in TVM Native API or NDArray::FromDLPack to regain zero copy behavior. 
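+  // In practice this means a raw DLTensor* argument (e.g. one arriving through the
+  // RPC minimal C runtime) is deep-copied into a freshly allocated NDArray on `dev`,
+  // ObjectRef arguments (NDArray / Array) are converted recursively by
+  // ConvertObjectToDevice, and plain POD arguments are forwarded unchanged.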
+ TVMRetValue ret; + + if (input.type_code() == kTVMDLTensorHandle) { + DLTensor* tensor = input; + std::vector shape(tensor->shape, tensor->shape + tensor->ndim); + auto dst = alloc->Empty(shape, tensor->dtype, dev); + dst.CopyFrom(tensor); + ret = dst; + } else if (input.IsObjectRef()) { + ret = ConvertObjectToDevice(input.operator ObjectRef(), dev, alloc); + } else { + ret = input; + } + return ret; +} + +TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev, Allocator* alloc) { + TVMRetValue ret; + if (input.IsObjectRef()) { + ret = ConvertObjectToDevice(input.operator ObjectRef(), dev, alloc); + } else { + ret = input; + } + return ret; +} + +//----------------------------------------------------------- +// VM implementations. +//----------------------------------------------------------- +/*! + * \brief The register type. + */ +using RegType = TVMRetValue; + +/*! + * \brief A representation of a stack frame. + * + * A stack frame is a record containing the information needed + * to restore the caller's virtual machine state after returning + * from a function call. + */ +struct VMFrame { + /*! \brief The return program counter. */ + Index return_pc; + /*! \brief Statically allocated space for objects */ + std::vector register_file; + /*! \brief Register in caller's frame to put return value */ + RegName caller_return_register; + // The following fields are used for PackedFunc call within + // a single function scope. The space is reused across multiple + // packed func calls to increase cache locality and avoid re-allocation + /*! \brief Temporary argument value stack for packed func call. */ + std::vector call_arg_values; + /*! \brief Temporary argument tcode stack for packed func call. */ + std::vector call_arg_tcodes; + + VMFrame(Index pc, Index register_file_size) + : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} +}; + +class VirtualMachineImpl : public VirtualMachine { + public: + //--------------------------------------------------- + // Public facing functions overloading + //--------------------------------------------------- + void LoadExecutable(ObjectPtr exec) final; + + void Init(const std::vector& devices, + const std::vector& alloc_types) final; + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; + + VMClosure GetClosure(const String& func_name) final { + return this->GetClosureInternal(func_name, false).value(); + } + + void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) final; + + void SetInstrument(PackedFunc instrument) final { this->instrument_ = instrument; } + + //-------------------------------------------------- + // Additional support arguments functions for VM + //-------------------------------------------------- + /*! + * \brief Internal implementation of GetClosure which also allow none. + * \param func_name The name of the function. + * \param allow_missing Whether none is allowed. + * \return The result + */ + Optional GetClosureInternal(const String& func_name, bool allow_missing); + + /*! + * \brief Set inputs to a function. + * \param func_name The function name. + * \param args args[offset:] are arguments to the function. If the arguments are not of the + * correct device for the function, they will be copied to the device. + * \param offset Starting offset of the arguments in \p args. 
+ * \param with_param_module If set to true, the last argument will be a module and can be invoked + * to get the argument, this is mainly used for debugging purposes and setting composite + * objects. \note This interface works when using VM over RPC by internally converting NDArray in + * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C + * runtime. + */ + void SetInput(std::string func_name, TVMArgs args, int offset, bool with_param_module = false); + + /*! + * \brief Look up whether the VM has a function by the given name. + * \param func_name the function's name + * \return The function, if it exists. Logs a fatal error if not. + */ + VMFuncInfo LookupVMFuncInfo(const std::string& func_name); + + /*! + * \brief Look up whether the VM has outputs for the given function. + * \param func_name the function's name + * \return The output, if it exists. Logs a fatal error if not. + */ + RegType LookupVMOutput(const std::string& func_name); + + /*! + * \brief Fully bind the argument of a global function and save it in the env. + * \param func_name The global function name to be saved. + * \param save_name The saved name of the function. + * \param include_return Whether forward the return value, set it to false allows + * us to ignore forwarding return value, which can be helpful to do benchmarking + * in RPC environment when return value is complicated ADT. + * + * \param args The arguments to bound to the function. + * \note This function is used by RPC server to help benchmarking. + */ + void SaveClosure(const String& func_name, const String& save_name, bool include_return, + TVMArgs args); + /*! + * \brief Internal function to invoke a closure. + * \param closure_or_packed The closure to be invoked. + * \param args The arguments to the function. + * \return The result value. + */ + RegType InvokeClosureInternal(const ObjectRef& closure_or_packed, + const std::vector& args); + /*! + * \brief Invoke a VM function by interpreting bytecode. + * \param fidx The function index. + * \param args The arguments to the function. + * \return The object representing the result. + */ + RegType InvokeBytecode(Index fidx, const std::vector& args); + + protected: + /*! + * \brief Get function by querying all of the current module's imports. + * \param name The name of the function. + * \return The result function, can return PackedFunc(nullptr) if nothing is found. + */ + PackedFunc GetFuncFromImports(const String& name) { + for (auto& lib : this->imports_) { + PackedFunc func = lib->GetFunction(name, true); + if (func.defined()) return func; + } + return PackedFunc(nullptr); + } + /*! + * \brief Initialize function pool. + */ + void InitFuncPool(); + //------------------------------------------------- + // Instruction interpretations. + //------------------------------------------------- + /*! + * \brief Push a call frame onto the call stack. + * \param ret_pc The program counter to return to. + * \param vm_func The function to be pushed to the call stack. + */ + void PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { + frames_.emplace_back(std::make_unique(ret_pc, vm_func.register_file_size)); + } + /*! + * \brief Pop a frame off the call stack. + */ + void PopFrame() { + ICHECK_GT(frames_.size(), 0); + pc_ = frames_.back()->return_pc; + frames_.pop_back(); + } + /*! + * \brief Write to a VM register. + * \param frame current vm frame. + * \param reg The register to write to. + * \param obj The object to write to. 
+ */ + void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { + ICHECK_LT(reg, frame->register_file.size()); + frame->register_file[reg] = obj; + } + /*! + * \brief Read a VM register. + * \param frame current vm frame. + * \param reg The register to read from. + * \return The value of the register. + */ + RegType ReadRegister(VMFrame* frame, RegName reg) { + if (reg < Instruction::kBeginSpecialReg) { + return frame->register_file[reg]; + } + RegType ret; + if (reg == Instruction::kVoidRegister) { + ret = nullptr; + } else { + ICHECK_EQ(reg, Instruction::kVMRegister); + // per convention, ctx ptr must be VirtualMachine* casted to void. + // this and VirtualMachine* may or maynot be the same + // do first cast to VirtualMachine* then to void* + ret = static_cast(static_cast(this)); + } + return ret; + } + /*! + * \brief Run call instruction. + * \param curr_frame The current frame. + * \param inst The call instruction. + */ + virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); + + /*! \brief Run VM dispatch loop. */ + void RunLoop(); + + /*! + * \brief Retrieve the name of the function identified by the given index. + * \param idx The index into the VM executable function table. + * \return The name of the function. + */ + const std::string& GetFuncName(int idx) { return exec_->func_table[idx].name; } + + /*! + * \brief Retrieve the inputs for a function. + * \param func_name The name of the function. + * \return The function inputs. + */ + const std::vector& GetInputsFor(const std::string& func_name) { + return inputs_[func_name]; + } + + void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); } + + //-------------------------------------------------------- + // Internal states for execution. + //-------------------------------------------------------- + /*! \brief The loaded executable. */ + ObjectPtr exec_; + /*! \brief The global constant pool */ + std::vector const_pool_; + /*! + * \brief Function pool to cache functions in func_table + */ + std::vector func_pool_; + //-------------------------------------------------------- + // Executor interface support + //-------------------------------------------------------- + /*! \brief The function name to input register mapping. */ + std::unordered_map> inputs_; + /*! \brief The function name to output register. */ + std::unordered_map outputs_; + /*! \brief A store of closures created by `save_function`. */ + std::unordered_map saved_closures_; + //------------------------------------------------------------ + // VM Instruction execution. + //------------------------------------------------------------ + /*! + * \brief The current stack of call frames. + * \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized. + */ + std::vector> frames_; + /*! \brief The virtual machine PC. */ + Index pc_{0}; + /*! \brief The special return register. */ + RegType return_value_; + /*!\ brief instrument function. 
*/ + PackedFunc instrument_ = nullptr; +}; + +void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { + this->exec_ = exec; + this->imports_ = exec_->imports(); +} + +void VirtualMachineImpl::Init(const std::vector& devices, + const std::vector& alloc_types) { + // TODO(@yuchen): support multi-device heterogeneous execution + ICHECK_LT(devices.size(), 3) + << "Currently relax vm only supports at most 2 devices (host + device)"; + ICHECK_EQ(devices.size(), alloc_types.size()); + + this->devices.reserve(devices.size()); + this->allocators.reserve(alloc_types.size()); + for (size_t i = 0; i < devices.size(); i++) { + auto alloc = MemoryManager::GetOrCreateAllocator(devices[i], alloc_types[i]); + this->devices.push_back(devices[i]); + this->allocators.push_back(alloc); + } + // Setup constant sections. + this->const_pool_.reserve(exec_->constants.size()); + for (const auto& constant : exec_->constants) { + if (constant.type_code() != kTVMNDArrayHandle) { + this->const_pool_.push_back(constant); + } else { + this->const_pool_.push_back(ConvertRegToDevice(constant, devices[0], allocators[0])); + } + } + // Setup function sections. + this->InitFuncPool(); +} + +VMFuncInfo VirtualMachineImpl::LookupVMFuncInfo(const std::string& func_name) { + ICHECK(exec_) << "The executable is not created yet."; + auto it = this->exec_->func_map.find(func_name); + CHECK(it != this->exec_->func_map.end()) << "ValueError: Unknown function: " << func_name; + + return exec_->func_table[it->second]; +} + +RegType VirtualMachineImpl::LookupVMOutput(const std::string& func_name) { + if (!outputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No output saved for call of \"" << func_name + << "\"; use `invoke_stateful` to call it first."; + } + return outputs_[func_name]; +} + +PackedFunc VirtualMachineImpl::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "vm_initialization") { + // initialize the VirtualMachine, takes variable-length arguments + // first argument is a runtime::Module, followed by one or more device_type, device_id, + // and the AllocatorType associated with the device. 
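+    // Illustrative call: registering a single CPU device that uses the pooled
+    // allocator would pass one triple, e.g.
+    //   vm_initialization(kDLCPU, /*device_id=*/0, static_cast<int>(AllocatorType::kPooled));
+    // each further (device_type, device_id, alloc_type) triple registers one more device.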
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size() % 3, 0); + std::vector devices; + std::vector alloc_types; + for (int i = 0; i < args.size(); i += 3) { + Device dev; + int device_type = args[i]; + dev.device_type = DLDeviceType(device_type); + dev.device_id = args[i + 1]; + int type = args[i + 2]; + devices.push_back(dev); + alloc_types.push_back(AllocatorType(type)); + } + this->Init(devices, alloc_types); + }); + } else if (name == "save_function") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.size(), 3); + this->SaveClosure(args[0], args[1], args[2], + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + }); + } else if (name == "invoke_closure") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + VMClosure clo = args[0]; + this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), + rv); + }); + } else if (name == "set_instrument") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + PackedFunc func; + if (args[0].type_code() != kTVMPackedFuncHandle) { + String func_name = args[0]; + const PackedFunc* factory = Registry::Get(func_name); + ICHECK(factory != nullptr) << "Cannot find factory " << func_name; + TVMRetValue rv; + factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1), &rv); + func = rv; + } else { + func = args[0]; + } + this->SetInstrument(func); + }); + } else if (name == "invoke_stateful") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const auto& m = this->exec_->func_map; + if (m.find(func_name) == m.end()) { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + Index gf_idx = m.at(func_name); + if (!inputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No inputs set for stateful call of " << func_name + << "; use `set_input` first."; + return; + } + outputs_[func_name] = this->InvokeClosureInternal(func_pool_[gf_idx], inputs_[func_name]); + }); + } else if (name == "get_output_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + // after chasing through the indices, examine the final object + if (const auto* arr = obj.as()) { + *rv = static_cast(arr->size()); + } else { + *rv = -1; + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + if (obj.as()) { + LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. 
" + "Please specify another index argument."; + return; + } + *rv = obj; + }); + } else if (name == "set_input") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); + } else if (name == "set_input_with_param_module") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1, true); }); + } else if (name == "get_function_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); + *rv = static_cast(vm_func.param_names.size()); + }); + } else if (name == "get_function_param_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + int index = args[1]; + const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); + if (static_cast(index) >= vm_func.param_names.size()) { + LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << index << " out of " + << vm_func.param_names.size() << ")"; + } + *rv = vm_func.param_names[index]; + }); + } else { + // default case, look up closure in VM. + if (Optional opt = this->GetClosureInternal(name, true)) { + auto clo = opt.value(); + return PackedFunc([sptr_to_self, this, clo](TVMArgs args, TVMRetValue* rv) { + this->InvokeClosurePacked(clo, args, rv); + }); + } else { + return PackedFunc(nullptr); + } + } +} + +void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int offset, + bool with_param_module) { + const auto& m = exec_->func_map; + if (m.find(func_name) != m.end()) { + Index gf_idx = m.at(func_name); + const VMFuncInfo& vm_func = exec_->func_table[gf_idx]; + size_t params_num = vm_func.num_args; + ICHECK_EQ(args.size() - offset, params_num) + << "The number of provided parameters doesn't match the number of arguments for"; + std::vector func_args(params_num); + + for (int i = offset; i < args.size(); ++i) { + int index = i - offset; + if (with_param_module && i == args.size() - 1) { + // call param func to get the arguments(usually corresponds to param pack.) + func_args[index] = (args[i].operator Module()).GetFunction("get_params")(); + } else { + func_args[index] = ConvertArgToDevice(args[i], devices[0], allocators[0]); + } + } + inputs_[func_name] = func_args; + } else { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } +} + +//------------------------------------------ +// Closure handling +//------------------------------------------ +void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) { + // run packed call if it is a packed func. + if (auto* packed = closure_or_packedfunc.as()) { + packed->CallPacked(args, rv); + return; + } + // run closure call. + auto* clo = closure_or_packedfunc.as(); + ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; + + std::vector values(args.size() + 1); + std::vector tcodes(args.size() + 1); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + // per convention, ctx ptr must be VirtualMachine* casted to void. 
+ // this and VirtualMachine* may or maynot be the same + // do first cast to VirtualMachine* then to void* + setter(0, static_cast(static_cast(this))); + std::copy(args.values, args.values + args.size(), values.begin() + 1); + std::copy(args.type_codes, args.type_codes + args.size(), tcodes.begin() + 1); + clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), args.size() + 1), rv); +} + +// internal variant version of invoke closurepacked +RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_packed, + const std::vector& args) { + RegType ret; + auto* packed = closure_or_packed.as(); + auto* clo = closure_or_packed.as(); + int clo_offset = clo != nullptr ? 1 : 0; + std::vector values(args.size() + clo_offset); + std::vector tcodes(args.size() + clo_offset); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + + if (clo != nullptr) { + setter(0, static_cast(static_cast(this))); + } + for (size_t i = 0; i < args.size(); ++i) { + setter(i + clo_offset, args[i]); + } + + if (packed != nullptr) { + packed->CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &ret); + } else { + ICHECK(clo != nullptr); + clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &ret); + } + return ret; +} + +void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, + bool include_return, TVMArgs args) { + VMClosure clo = this->GetClosure(func_name); + std::vector inputs(args.size()); + for (int i = 0; i < args.size(); ++i) { + inputs[i] = ConvertArgToDevice(args[i], this->devices[0], this->allocators[0]); + } + PackedFunc impl = VMClosure::BindLastArgs(clo->impl, inputs); + if (!include_return) { + impl = PackedFunc([impl](TVMArgs args, TVMRetValue* rv) { + TVMRetValue temp; + impl.CallPacked(args, &temp); + }); + } + saved_closures_[save_name] = VMClosure(save_name, impl); +} + +Optional VirtualMachineImpl::GetClosureInternal(const String& func_name, + bool allow_missing) { + // look up saved closures. + auto saved_it = saved_closures_.find(func_name); + if (saved_it != saved_closures_.end()) { + return saved_it->second; + } + auto it = exec_->func_map.find(func_name); + if (it == exec_->func_map.end()) { + if (allow_missing) return NullOpt; + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + + Index gf_idx = it->second; + const VMFuncInfo& finfo = exec_->func_table[gf_idx]; + + if (finfo.kind == VMFuncInfo::FuncKind::kVMFunc) { + // NOTE: should not capture strong ref to self and avoid cyclic ref. 
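+    // Calling convention of the closure body: args[0] carries the VirtualMachine*
+    // context pointer (as void*), and args[1:] are the user-facing arguments that
+    // get forwarded to InvokeBytecode for the captured function index.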
+ auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { + // Per convention, ctx ptr is a VirtualMachine* + VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); + + std::vector inputs(args.size() - 1); + for (size_t i = 0; i < inputs.size(); ++i) { + inputs[i] = args[i + 1]; + } + *rv = static_cast(ctx_ptr)->InvokeBytecode(gf_idx, inputs); + }); + return VMClosure(func_name, impl); + } else { + ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) + << "Cannot support closure with function kind " << static_cast(finfo.kind); + PackedFunc tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " + << finfo.name; + auto impl = PackedFunc([this, finfo, tir_func](TVMArgs args, TVMRetValue* rv) { + // Per convention, ctx ptr is a VirtualMachine* + VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); + ICHECK(ctx_ptr == this); + ICHECK_EQ(args.size() - 1, finfo.num_args) + << "Function " << finfo.name << " expects " << finfo.num_args << " arguments"; + ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); + std::vector reg_file(finfo.register_file_size); + for (int64_t i = 0; i < finfo.num_args; ++i) { + reg_file[i] = args[i + 1]; + } + void* reg_anylist_handle = reg_file.data(); + void* const_anylist_handle = this->const_pool_.data(); + void* func_anylist_handle = this->func_pool_.data(); + tir_func(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, + func_anylist_handle); + // Return value always stored after inputs. + *rv = reg_file[finfo.num_args]; + }); + return VMClosure(func_name, impl); + } +} + +//-------------------------------------------------------------------- +// Instruction interpretations. +//-------------------------------------------------------------------- +RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector& args) { + const VMFuncInfo& gfunc = exec_->func_table[gf_idx]; + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + + // Get the curr instr which might be a potential caller. + Instruction curr_instr = exec_->GetInstruction(pc_); + PushFrame(this->pc_, gfunc); + // Get new frame and set the caller info. 
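+  // If the instruction at the saved pc is a Call, record its destination register so
+  // that the matching Ret can write this invocation's return value back into the
+  // caller's frame (see the Ret case in RunLoop).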
+ VMFrame* curr_frame = frames_.back().get(); + if (curr_instr.op == Opcode::Call) { + curr_frame->caller_return_register = curr_instr.dst; + } + + // load arguments to the register file + ICHECK_EQ(static_cast(gfunc.num_args), args.size()) + << "ValueError: Invoking function " << gfunc.name << " requires " << gfunc.num_args + << " inputs but only " << args.size() << " inputs are provided."; + for (size_t i = 0; i < args.size(); ++i) { + WriteRegister(frames_.back().get(), i, args[i]); + } + // set program counter + pc_ = gfunc.start_instr; + RunLoop(); + return return_value_; +} + +void VirtualMachineImpl::InitFuncPool() { + func_pool_.resize(exec_->func_table.size()); + + for (size_t func_index = 0; func_index < exec_->func_table.size(); ++func_index) { + const VMFuncInfo& info = exec_->func_table[func_index]; + if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { + // only look through imports first + PackedFunc func = GetFuncFromImports(info.name); + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(info.name); + if (p_func != nullptr) func = *(p_func); + } + ICHECK(func.defined()) + << "Error: Cannot find PackedFunc " << info.name + << " in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in " + "global Relax functions of the VM executable"; + func_pool_[func_index] = func; + + } else { + ICHECK(info.kind == VMFuncInfo::FuncKind::kVMFunc || + info.kind == VMFuncInfo::FuncKind::kVMTIRFunc); + auto clo = this->GetClosure(info.name); + func_pool_[func_index] = clo; + } + } +} + +void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); + int args_begin_offset = instrument_ != nullptr ? 4 : 0; + // Use the call arg stack from the current frame to increase reuse + // and avoid re-allocation + curr_frame->call_arg_values.resize(args_begin_offset + instr.num_args); + curr_frame->call_arg_tcodes.resize(args_begin_offset + instr.num_args); + + // NOTE: no changes and resize to those vector ref(otherwise can leads to segfault) + // in the remainder part of the function. 
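+  // Argument-stack layout: when an instrument callback is installed, slots [0..3]
+  // are reserved for (callee, func_name, before_run flag, return value) and the real
+  // call arguments start at args_begin_offset; otherwise they start at slot 0.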
+ std::vector& values = curr_frame->call_arg_values; + std::vector& tcodes = curr_frame->call_arg_tcodes; + + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (Index i = 0; i < instr.num_args; ++i) { + Instruction::Arg arg = instr.args[i]; + int arg_index = args_begin_offset + i; + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: { + setter(arg_index, ReadRegister(curr_frame, arg.value())); + break; + } + case Instruction::ArgKind::kImmediate: { + setter(arg_index, arg.value()); + break; + } + case Instruction::ArgKind::kConstIdx: { + setter(arg_index, this->const_pool_[arg.value()]); + break; + } + case Instruction::ArgKind::kFuncIdx: { + ICHECK_LT(static_cast(arg.value()), this->func_pool_.size()); + setter(arg_index, this->func_pool_[arg.value()]); + break; + } + default: { + LOG(FATAL) << "ValueError: Unknown argument kind: " << int(arg.kind()); + } + } + } + TVMArgs args(values.data() + args_begin_offset, tcodes.data() + args_begin_offset, + instr.num_args); + TVMRetValue ret; + + ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); + + if (instrument_ == nullptr) { + this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); + } else { + // insert light-weight instrument callback + setter(0, func_pool_[instr.func_idx]); + setter(1, GetFuncName(instr.func_idx)); + setter(2, true); + setter(3, nullptr); + TVMRetValue rv; + // store dtype to str since py callback cannot handle dtype atm. + std::vector> temp_dtype; + for (int i = 0; i < instr.num_args; ++i) { + if (tcodes[i + args_begin_offset] == kTVMDataType) { + std::string str_dtype = args[i]; + temp_dtype.emplace_back(std::make_unique(str_dtype)); + setter(i + args_begin_offset, *temp_dtype.back()); + } + } + int ret_kind = static_cast(VMInstrumentReturnKind::kNoOp); + instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv); + if (rv.type_code() == kDLInt) { + ret_kind = rv; + } + if (ret_kind != static_cast(VMInstrumentReturnKind::kSkipRun)) { + this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); + setter(2, false); + setter(3, ret); + instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv); + } + } + + // save the return value to the register + // saving to special register is a NOP + if (instr.dst < Instruction::kBeginSpecialReg) { + WriteRegister(curr_frame, instr.dst, ret); + } + // increment pc + pc_++; +} + +void VirtualMachineImpl::RunLoop() { + VMFrame* curr_frame = frames_.back().get(); + + while (true) { + ICHECK_LT(static_cast(pc_), exec_->instr_offset.size()) << "run into invalide section"; + Instruction instr = exec_->GetInstruction(pc_); + switch (instr.op) { + case Opcode::Call: { + this->RunInstrCall(curr_frame, instr); + break; + } + case Opcode::Ret: { + // If we have hit the point from which we started + // running, we should return to the caller breaking + // the dispatch loop. + return_value_ = ReadRegister(curr_frame, instr.result); + RegName caller_return_register = curr_frame->caller_return_register; + PopFrame(); + if (frames_.size() == 0) { + // directly return if no frame in the call stack. + } else { + // return from a local call. + // Update the current frame to be the parent frame. 
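+        // The return value is then delivered into the register that the caller's
+        // Call instruction recorded as its destination (caller_return_register).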
+ curr_frame = frames_.back().get(); + WriteRegister(curr_frame, caller_return_register, return_value_); + } + return; + } + case Opcode::Goto: { + pc_ += instr.pc_offset; + break; + } + case Opcode::If: { + int64_t cond_val = ReadRegister(curr_frame, instr.cond); + if (cond_val != 0) { + pc_++; + } else { + ICHECK_GT(instr.false_offset, 1); + pc_ += instr.false_offset; + } + break; + } + } + } +} + +ObjectPtr VirtualMachine::Create() { return make_object(); } + +//---------------------------------------------------------------- +// Profiler can be optionally disabled via a macro to reduce dep. +//---------------------------------------------------------------- +#if TVM_RELAX_VM_ENABLE_PROFILER + +/*! + * \brief An extension of VirtualMachineImpl to support per-op profiling + * It overrides RunInstrCall to add instrumentations around it. + */ +class VirtualMachineProfiler : public VirtualMachineImpl { + public: + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (name == "profile") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string f_name = args[0]; + VMClosure clo = this->GetClosure(f_name); + + std::vector devices; + for (auto dev : this->devices) { + if (dev.device_type > 0) { + devices.push_back(dev); + } + } + + prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); + + auto inputs = GetInputsFor(f_name); + + bool clear_inputs = false; + if (inputs.size() == 0) { + ICHECK(args.num_args > 1) << "No input is provided"; + TVMArgs f_args(args.values + 1, args.type_codes + 1, args.num_args - 1); + SetInput(f_name, args, 1); + inputs = GetInputsFor(f_name); + clear_inputs = true; + } else { + ICHECK_EQ(args.num_args, 1) << "Inputs are already provided by set_input."; + } + + // warmup + this->InvokeClosureInternal(clo, inputs); + + prof_->Start(); + this->InvokeClosureInternal(clo, inputs); + prof_->Stop(); + + // Return the report as json, since profiling::Report object is not supported by RPC + std::string report_json = prof_->Report()->AsJSON(); + *rv = report_json; + + prof_ = std::nullopt; // releases hardware counters + if (clear_inputs) { + // SetInput modifies the internal states of VM. Undo the change after profiling. + ClearInputsFor(f_name); + } + }); + } else { + return VirtualMachineImpl::GetFunction(name, sptr_to_self); + } + } + + protected: + void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { + bool profiling = false; + if (prof_ && prof_->IsRunning()) { + auto f_name = GetFuncName(inst.func_idx); + std::optional dev; + std::vector arrs; + + auto f_check_ndarray_arg = [&dev, &arrs](const RegType& arg) { + if (arg.type_code() == kTVMNDArrayHandle) { + NDArray arr = arg; + dev = arr->device; + arrs.push_back(arr); + } + }; + + for (Index i = 0; i < inst.num_args; ++i) { + Instruction::Arg arg = inst.args[i]; + if (arg.kind() == Instruction::ArgKind::kRegister) { + auto reg = ReadRegister(curr_frame, arg.value()); + f_check_ndarray_arg(reg); + } else if (arg.kind() == Instruction::ArgKind::kConstIdx) { + const auto& const_val = this->const_pool_[arg.value()]; + f_check_ndarray_arg(const_val); + } + } + + std::unordered_map metrics; + metrics["Argument Shapes"] = profiling::ShapeString(arrs); + + // If a sutiable device is found, enable profiling. 
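+      // Calls that touch no NDArray argument report no device and are therefore
+      // executed without a surrounding StartCall/StopCall pair.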
+ if (dev) { + profiling = true; + prof_->StartCall(f_name, *dev, metrics); + } + } + + VirtualMachineImpl::RunInstrCall(curr_frame, inst); + + if (profiling) { + prof_->StopCall(); + } + } + + private: + std::optional prof_; +}; + +ObjectPtr VirtualMachine::CreateProfiler() { + return make_object(); +} + +#else +ObjectPtr VirtualMachine::CreateProfiler() { + LOG(FATAL) << "Profiler support is disabled"; + return nullptr; +} +#endif // TVM_RELAX_VM_ENABLE_PROFILER +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..3d917cee887b 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,15 +26,20 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs, global_infos); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..906c453ba0dc 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -17,9 +17,12 @@ * under the License. 
*/ #include +#include #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +30,74 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame(); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + if (func_signature->struct_info_.defined()) { + gv->struct_info_ = tvm::relax::GetStructInfo(func_signature); + } else if (const auto* prim_func = func_signature.as()) { + gv->struct_info_ = + tvm::relax::FuncStructInfo::OpaqueFunc(tvm::relax::StructInfoFromType(prim_func->ret_type)); + } else { + LOG(FATAL) << "Unsupported function type: " << func_signature->GetTypeKey(); + } + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + frame->functions.Set(gv, func_signature); + ICHECK(func_signature->checked_type_.defined()) + << "The checked_type_ of function signature must be defined."; + gv->checked_type_ = func_signature->checked_type_; + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame(); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + CHECK(func->checked_type_.defined()) + << "The checked_type_ of function must be defined, but it is not defined for function `" + << func_name << "`."; + gv->checked_type_ = func->checked_type_; +} + +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + +void ModuleGlobalInfos(Map> global_infos) { + if (IRBuilder::IsInScope()) { + IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); + if (!frame->global_infos.empty()) { + LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n" + << frame->global_infos; + } + frame->global_infos = global_infos; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..b12e5e270d89 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +inline IRModuleFrame FindModuleFrame() { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure it" + << " is called under I.ir_module()"; + } + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc new file mode 100644 index 000000000000..c78b9e73c534 --- /dev/null +++ b/src/script/ir_builder/relax/frame.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +void SeqExprFrameNode::ExitWithScope() { + // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call + // its `ExitBlockFrame` and check if there is any more unended BlockFrame. 
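+  // The implicit block opened by EnterWithScope may already have been closed, for
+  // instance when the function body emitted its return value (see FuncRetValue), in
+  // which case there is nothing left to pop here.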
+ if (Optional block_frame = IRBuilder::Current()->GetLastFrame()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) + << "ValueError: There is some remaining BlockFrame that is not properly popped out."; + } + RelaxFrameNode::ExitWithScope(); +} + +void SeqExprFrameNode::EnterWithScope() { + RelaxFrameNode::EnterWithScope(); + BindingBlock()->EnterWithScope(); +} + +void FunctionFrameNode::ExitWithScope() { + using ir::IRModuleFrame; + using tvm::relax::Expr; + IRBuilder builder = IRBuilder::Current(); + SeqExprFrameNode::ExitWithScope(); + // Step 1: Create the function. + CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " + "`return` to return an Expr"; + this->block_builder->BeginScope(params); + Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + this->block_builder->EndScope(); + tvm::relax::Function func(/*params=*/params, + /*body=*/body, + /*ret_struct_info=*/ret_struct_info, + /*attrs=*/dict_attrs); + // Step 2: Update IRModule. + if (builder->frames.empty()) { + // Case 0. No outer frame, return function directly + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + // Case 1. A global function of an IRModule + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert Relax.Function"; + } +} + +void BlockFrameNode::EnterWithScope() { + // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the + // last block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + // Block frames cannot appear consecutively. + ICHECK(!IRBuilder::Current()->GetLastFrame()); + } + // Step 2. Deal with the new block frame. 
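+  // The new block is opened on the enclosing FunctionFrame's BlockBuilder: a dataflow
+  // block when is_dataflow is set, and a plain binding block otherwise.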
+ RelaxFrameNode::EnterWithScope(); + Optional func_frame = IRBuilder::Current()->FindFrame(); + CHECK(func_frame.defined()) + << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure " + "creating the block under Relax function scope."; + const tvm::relax::BlockBuilder& block_builder = func_frame.value()->block_builder; + if (is_dataflow) { + block_builder->BeginDataflowBlock(); + } else { + block_builder->BeginBindingBlock(); + } +} + +class DataflowBlockRewriter : public tvm::relax::ExprMutator { + public: + static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block, + const Array& output_vars) { + DataflowBlockRewriter rewriter(output_vars); + return Downcast(rewriter.VisitBindingBlock(block)); + } + + private: + explicit DataflowBlockRewriter(const Array& output_vars) { + for (const tvm::relax::Var& var : output_vars) { + output_var_set_.insert(var.get()); + } + } + + tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final { + auto it = output_var_set_.find(op); + if (it != output_var_set_.end()) { + // Rewrite dataflow vars to global vars + auto n = make_object(*op); + tvm::relax::Var new_var(n); + this->var_remap_[op->vid] = new_var; + return new_var; + } else { + return GetRef(op); + } + } + + private: + std::unordered_set output_var_set_; +}; + +void BlockFrameNode::ExitWithScope() { + // Step 1. Pop the current frame out of the frame stack. + RelaxFrameNode::ExitWithScope(); + + // Step 2. Get the constructed binding block from the block builder. The block should have at + // lease one binding - otherwise, the block is not supposed to be created. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::BindingBlock block = block_builder->EndBlock(); + if (block->bindings.empty()) { + return; + } + + // Step 3. Rewrite the dataflow block. + if (is_dataflow) { + // Step 3.1. Rewrite block binding + block = DataflowBlockRewriter::Rewrite(Downcast(block), output_vars); + + // Step 3.2. Collect global vars' reference in bindings + Map new_global_vars; + for (const tvm::relax::Binding& binding : block->bindings) { + if (!binding->var->IsInstance()) { + new_global_vars.Set(binding->var->vid, binding->var); + } + } + + // Step 3.3. Rewrite output vars + Array new_output_vars; + for (const auto& var : output_vars) { + auto it = new_global_vars.find(var->vid); + ICHECK(it != new_global_vars.end()); + new_output_vars.push_back((*it).second); + } + output_vars = std::move(new_output_vars); + } + + // Step 3. Get the last frame from the IRBuilder frame stack. + Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); + ICHECK(opt_last_frame.defined()); + RelaxFrame last_frame = opt_last_frame.value(); + + // Step 4. Since we popped out any possible block frame when entering the "with" scope of the + // current frame, the last frame cannot be a block frame. + ICHECK(!last_frame->IsInstance()); + + // Step 5. Push the block frame into the corresponding field of the last frame. + if (const auto* seq_frame = last_frame.as()) { + ICHECK(!seq_frame->output.defined()) + << "The function is not expected to have output values when emitting blocks."; + auto frame = GetRef(seq_frame); + frame->binding_blocks.push_back(block); + } else { + LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame " + "or a block frame. However, the last frame is \"" + << last_frame->GetTypeKey() << "\"."; + } + + // Step 6. Start another binding block when a dataflow block ended. 
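+  // This keeps the builder in a consistent state: bindings emitted after a dataflow
+  // block (but before the function returns) land in a fresh non-dataflow block.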
+ if (is_dataflow) { + BindingBlock()->EnterWithScope(); + } +} + +void IfFrameNode::EnterWithScope() { + const Array& frames = IRBuilder::Current()->frames; + for (const IRBuilderFrame& frame : frames) { + const auto* block_frame = frame.as(); + if (block_frame && block_frame->is_dataflow) { + LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block."; + } + } + RelaxFrameNode::EnterWithScope(); +} + +void IfFrameNode::ExitWithScope() { + RelaxFrameNode::ExitWithScope(); + CHECK(then_expr.defined()) + << "ValueError: The body of then part is expected to be defined before exiting."; + CHECK(then_expr.defined()) + << "ValueError: The body of else part is expected to be defined before exiting."; + auto body = tvm::relax::If(condition, then_expr.value(), else_expr.value()); + var = Emit(body); + IRBuilder::Name(var_name, var); +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Then"); + CHECK(!frame->then_expr.defined()) + << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Then"); + frame->then_expr = output; + frame->var_name = var_name; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Else"); + CHECK(frame->then_expr.defined()) << "The else branch should follow then branch"; + CHECK(!frame->else_expr.defined()) + << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Else"); + frame->else_expr = output; + CHECK(frame->var_name == var_name) + << "This last binding of both branches must have the same variable."; +} + +TVM_REGISTER_NODE_TYPE(FunctionFrameNode); +TVM_REGISTER_NODE_TYPE(SeqExprFrameNode); +TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc new file mode 100644 index 000000000000..71a0651de859 --- /dev/null +++ b/src/script/ir_builder/relax/ir.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +///////////////////////////////// Vars ////////////////////////////////// + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::VarNode; + using tvm::relax::IdNode; + const VarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::DataflowVarNode; + using tvm::relax::IdNode; + const DataflowVarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +/////////////////////////////// Function //////////////////////////////// + +FunctionFrame Function() { + ObjectPtr n = make_object(); + const IRBuilder& ir_builder = IRBuilder::Current(); + Optional mod = NullOpt; + if (const Optional mod_frame = ir_builder->GetLastFrame()) { + mod = tvm::IRModule(mod_frame.value()->functions); + } + n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); + return FunctionFrame(n); +} + +tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) { + FunctionFrame frame = FindFunctionFrame("R.Arg"); + tvm::relax::Var var(name, struct_info); + frame->params.push_back(var); + return var; +} + +void FuncName(const String& name) { + FunctionFrame frame = FindFunctionFrame("R.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() + << "\""; + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + FunctionFrame frame = FindFunctionFrame("R.func_attr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; +} + +void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { + FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); + if (frame->ret_struct_info.defined()) { + LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " + << frame->ret_struct_info.value(); + } + frame->ret_struct_info = ret_sinfo; +} + +void FuncRetValue(const tvm::relax::Expr& value) { + // Step 0. Normalize the value. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Expr normalized_value = block_builder->Normalize(value); + + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of + // a function body. Therefore if there is any unended block frame when dealing with function + // return, we should end the block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->FindFrame()) + << "ValueError: Relax functions don't support return in true/false branch of If Node."; + } + // Step 2. Add the output value to the function frame. + FunctionFrame frame = FindFunctionFrame("return"); + CHECK(!frame->output.defined()) + << "ValueError: Relax functions don't support multiple return statement. 
Please make sure " + "the return statement appears at the end of function."; + + frame->output = std::move(normalized_value); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); + +///////////////////////////// BindingBlock ////////////////////////////// + +BlockFrame Dataflow() { + ObjectPtr n = make_object(); + n->is_dataflow = true; + n->block_ended = false; + return BlockFrame(n); +} + +BlockFrame BindingBlock() { + ObjectPtr n = make_object(); + n->is_dataflow = false; + n->block_ended = false; + return BlockFrame(n); +} + +void DataflowBlockOutput(const Array& vars) { + // Step 1. Check that we're in a Dataflow block that is not ended. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined() && block_frame.value()->is_dataflow) + << "ValueError: `R.output` should appear inside a dataflow block. However, the current " + "innermost block is not a dataflow block."; + CHECK(!block_frame.value()->block_ended) + << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; + + // Step 2. Mark the block frame ended of construction, so that any followup binding after this + // mark in the dataflow block will lead to an error. + block_frame.value()->block_ended = true; + + // Step 3. All the output variables must be global variables and must be emitted by this dataflow + // block. + const Array& emitted_vars = block_frame.value()->emitted_vars; + for (const tvm::relax::Var& var : vars) { + CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) + << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " + "all dataflow block output variables are emitted exactly by this block."; + block_frame.value()->output_vars.push_back(var); + } +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") + .set_body_typed(DataflowBlockOutput); + +/////////////////////////////// Bindings /////////////////////////////// + +tvm::relax::Var Emit(const tvm::relax::Expr& expr, + const Optional& annotate_struct_info) { + using tvm::relax::GetStructInfo; + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + if (annotate_struct_info.defined()) { + const auto& sinfo = annotate_struct_info.value(); + if (!expr->struct_info_.defined()) { + UpdateStructInfo(expr, sinfo); + } else { + CHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != tvm::relax::BaseCheckResult::kFailL0) + << "Invalid annotation. 
Got rhs value struct info: " << GetStructInfo(expr) + << ", given struct info: " << sinfo; + } + } + tvm::relax::Var var = block_builder->Emit(expr); + block_frame->emitted_vars.push_back(var); + return var; +} + +tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + + tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); + block_frame->emitted_vars.push_back(var); + return var; +} + +tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + block_builder->EmitNormalized(binding); + block_frame->emitted_vars.push_back(binding->var); + return binding->var; +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); + +///////////////////////////// If Then Else ///////////////////////////// + +IfFrame If(tvm::relax::Expr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_expr = NullOpt; + n->else_expr = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h new file mode 100644 index 000000000000..ae91d05769bd --- /dev/null +++ b/src/script/ir_builder/relax/utils.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +inline FunctionFrame FindFunctionFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Function frame not find. 
Please ensure '" << method
+             << "' is called under R.function()";
+  throw;
+}
+
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  } else {
+    LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+               << "' is called under R.if_()";
+  }
+  throw;
+}
+
+inline tvm::relax::BlockBuilder GetBlockBuilder() {
+  Optional<FunctionFrame> frame = IRBuilder::Current()->FindFrame<FunctionFrame>();
+  CHECK(frame.defined()) << "ValueError: Relax Function frame not found. Please ensure "
+                            "assignment is called under R.function()";
+  return frame.value()->block_builder;
+}
+
+inline BlockFrame CheckBlockFrameExistAndUnended() {
+  // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new
+  // bindings into this block, and we should throw exceptions.
+
+  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  CHECK(block_frame.defined()) << "ValueError: Block frame not found";
+  CHECK(!block_frame.value()->block_ended)
+      << "ValueError: New binding is not allowed after dataflow block output.";
+  return block_frame.value();
+}
+
+inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) {
+  // Step 0. Check frame type
+  std::string method;
+  if (frame->IsInstance<ThenFrameNode>()) {
+    method = "R.Then";
+  } else if (frame->IsInstance<ElseFrameNode>()) {
+    method = "R.Else";
+  } else {
+    ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey();
+  }
+
+  // Step 1. Check non-empty block and last binding is non-dataflow
+  CHECK(!frame->binding_blocks.empty())
+      << "Empty body is not allowed for '" << method << "' statements.";
+  const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back();
+  CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty.";
+
+  // Step 2. Collect body from the last binding.
+  tvm::relax::Expr body;
+  const tvm::relax::Binding& last_binding = last_block->bindings.back();
+  if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
+    CHECK(!var_binding->var->IsInstance<tvm::relax::DataflowVarNode>())
+        << "A non-dataflow var is expected in the last binding of '" << method << "'.";
+    body = var_binding->value;
+    *var_name = var_binding->var->name_hint();
+  } else if (const auto* match_cast = last_binding.as<tvm::relax::MatchCastNode>()) {
+    CHECK(!match_cast->var->IsInstance<tvm::relax::DataflowVarNode>())
+        << "A non-dataflow var is expected in the last binding of '" << method << "'.";
+    body = match_cast->value;
+    *var_name = match_cast->var->name_hint();
+  } else {
+    ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey();
+  }
+
+  // Step 3. Re-collect binding blocks to remove the last binding.
+  Array<tvm::relax::BindingBlock> new_blocks(frame->binding_blocks.begin(),
+                                             frame->binding_blocks.end() - 1);
+  Array<tvm::relax::Binding> last_block_bindings(last_block->bindings.begin(),
+                                                 last_block->bindings.end() - 1);
+  new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings));
+
+  return tvm::relax::SeqExpr(new_blocks, body);
+}
+
+}  // namespace relax
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 1e63201a40dd..dd8d3c2ed3f3 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -16,6 +16,7 @@
 * specific language governing permissions and limitations
 * under the License.
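The helpers above enforce the dataflow-block discipline used by the builder: new bindings are rejected once a block has been marked ended, and `GetSeqExprForBranch` strips the final binding of a branch so it can become the branch's body. On the scripting side this corresponds to the following usage, sketched under the assumption that the Python frontend exposes `R.dataflow`/`R.output` as registered here (op names illustrative):

```python
from tvm.script import relax as R


@R.function
def df(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
    with R.dataflow():
        lv0 = R.add(x, x)        # DataflowVar, visible only inside the block
        gv = R.multiply(lv0, x)  # listed in R.output, so it escapes the block
        R.output(gv)             # ends the block; any later binding here is rejected
    return gv
```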
*/ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1fe..f3b547532cfd 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index e726cd42a241..54194e7e2a41 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -549,7 +549,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { if (doc->rhs) { output_ << " = "; if (const auto* tuple_doc = doc->rhs.as()) { - PrintJoinedDocs(tuple_doc->elements, ", "); + if (tuple_doc->elements.size() > 1) { + PrintJoinedDocs(tuple_doc->elements, ", "); + } else { + PrintDoc(doc->rhs.value()); + } } else { PrintDoc(doc->rhs.value()); } diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168ad..f23820927db6 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module")); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } + if (mod->global_infos.defined() && !mod->global_infos.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + IR(d, "module_global_infos") // + ->Call({d->AsDoc(mod->global_infos, p->Attr("global_infos"))}))); + } + // Declare GlobalVars first + IdDoc module_alias = d->cfg->module_alias.empty() ? 
module_doc : IdDoc(d->cfg->module_alias); + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + d->Define(gv, f(), [=]() { + return d->AsDoc(mod, p->Attr("global_vars"))->Attr(gv->name_hint); + }); + } + // Print functions for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; @@ -79,8 +99,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->stmts.push_back(Downcast(doc)); } } - return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")), - {IR(d, "ir_module")}, (*f)->stmts)); + return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -93,6 +112,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { + return IR(d, "dummy_global_info")->Call({}); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index fd5003073afb..d8dda4f74910 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -149,6 +149,10 @@ IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { auto n = make_object(); n->cfg = cfg; n->dispatch_tokens.push_back(""); + // Define builtin keywords according to cfg. + for (const String& keyword : cfg->GetBuiltinKeywords()) { + n->defined_names.insert(keyword); + } data_ = std::move(n); } diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc new file mode 100644 index 000000000000..8a50fe969850 --- /dev/null +++ b/src/script/printer/relax/binding.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
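The IRModule printing path above now pre-declares every `GlobalVar` before printing any function body, so forward and mutual references resolve, and it emits `I.module_attrs` / `I.module_global_infos` when those fields are present. A hedged sketch of the resulting script (bodies and attribute values are illustrative; cross-references print through the module name or the configured `module_alias`):

```python
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 1})

    @R.function
    def foo(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        # `Module.bar` resolves through the GlobalVar declared before any body is printed
        gv: R.Tensor((8,), dtype="float32") = Module.bar(x)
        return gv

    @R.function
    def bar(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        return x
```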
+ */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& d, // + const Optional& var, const Optional& ann) { + using relax::SeqExpr; + ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); + std::vector> branches{ + PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), + PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), + }; + if (var.defined()) { + for (Array& stmts : branches) { + ExprDoc ret = Downcast(stmts.back())->expr; + stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); + } + } + return IfDoc(cond, branches[0], branches[1]); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { + using relax::StructInfo; + using relax::MatchStructInfo; + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc rhs = Relax(d, "match_cast") + ->Call({d->AsDoc(n->value, n_p->Attr("value")), + d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (const auto if_ = n->value.as()) { + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return PrintIfExpr(GetRef(if_), n_p->Attr("value"), d, lhs, ann); + } else if (n->value->IsInstance()) { + IdDoc lhs = DefineVar(n->var, d->frames.back(), d); + d->cfg->binding_names.push_back(lhs->name); + Doc ret = d->AsDoc(n->value, n_p->Attr("value")); + d->cfg->binding_names.pop_back(); + return ret; + } else { + ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc { + return PrintIfExpr(n, n_p, d, NullOpt, NullOpt); + }); + +TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc new file mode 100644 index 000000000000..c32ab8be2f0e --- /dev/null +++ b/src/script/printer/relax/call.cc @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
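The binding printer above renders `MatchCast` nodes as `R.match_cast` assignments, and an `If` appearing as the right-hand side of a binding becomes a Python `if`/`else` whose branches both assign the bound variable. A minimal sketch of the `match_cast` form, assuming the usual `R`/`T` script namespaces (shapes illustrative):

```python
from tvm.script import relax as R
from tvm.script import tir as T


@R.function
def cast(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
    m = T.int64()
    n = T.int64()
    # MatchCastNode prints (and parses) as R.match_cast against the target struct info
    y = R.match_cast(x, R.Tensor((m, n), "float32"))
    return y
```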
+ */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class AttrPrinter : public tvm::AttrVisitor { + public: + explicit AttrPrinter(const ObjectPath& p, const IRDocsifier& d, Array* keys, + Array* values) + : p(p), d(d), keys(keys), values(values) {} + + void Visit(const char* key, double* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Float(*value, p->Attr(key))); + } + + void Visit(const char* key, int64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, uint64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, int* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, bool* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Boolean(*value, p->Attr(key))); + } + + void Visit(const char* key, std::string* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Str(*value, p->Attr(key))); + } + + void Visit(const char* key, DataType* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::DataType(*value, p->Attr(key))); + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + keys->push_back(key); + values->push_back(d->AsDoc(*value, p->Attr(key))); + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "TypeError: void is not allowed in Attrs"; + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs"; + } + + const ObjectPath& p; + const IRDocsifier& d; + Array* keys; + Array* values; +}; + +ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifier& d) { + // TODO(@junrushao): handle callee better + if (const auto* ext = n.as()) { + return LiteralDoc::Str(ext->global_symbol, n_p); + } else { + return d->AsDoc(n, n_p); + } +} + +Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p, + const IRDocsifier& d) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) { + return NullOpt; + } + ICHECK(n->args.size() == 2 || n->args.size() == 3); + ICHECK(n->sinfo_args.size() == 1); + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print n->args[0], the callee + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + // Step 2. Print n->args[1], the input arguments + args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1))); + // Step 3. Print n->sinfo_args, the output struct info + relax::StructInfo o_sinfo = n->sinfo_args[0]; + ObjectPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayIndex(0); + kwargs_keys.push_back("out_sinfo"); + if (const auto* o = o_sinfo.as()) { + Array fields; + ObjectPath fields_p = o_sinfo_p->Attr("fields"); + for (int i = 0, l = o->fields.size(); i < l; ++i) { + fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayIndex(i))); + } + kwargs_values.push_back(ListDoc(fields)); + } else { + kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); + } + if (n->op.same_as(call_dps_packed_op)) { + return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values); + } + // Step 4. 
Print n->args[2], the tir variables + if (n->args.size() == 3) { + kwargs_keys.push_back("tir_vars"); + kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayIndex(2))); + } + return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { + // Special case: call_tir, call_dps_packed + if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { + return doc.value(); + } + ExprDoc prefix{nullptr}; + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print op + if (const auto* op = n->op.as()) { + prefix = Relax(d, "call_packed"); + args.push_back(LiteralDoc::Str(op->global_symbol, n_p->Attr("op"))); + } else if (const auto* op = n->op.as()) { + std::string name = op->name; + if (name.rfind("relax.", 0) == 0) { + prefix = Relax(d, name.substr(6)); + } else { + prefix = IdDoc(name); + } + prefix->source_paths.push_back(n_p->Attr("op")); + } else if (n->op->IsInstance() || + n->op->IsInstance()) { + prefix = d->AsDoc(n->op, n_p->Attr("op")); + } else { + LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey(); + } + // Step 2. Print args + if (!n->args.empty()) { + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + } + for (int i = 1, l = n->args.size(); i < l; ++i) { + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + } + // Step 3. Print attrs + if (n->attrs.defined()) { + if (n->op->IsInstance()) { + kwargs_keys.push_back("attrs_type_key"); + kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), n_p->Attr("attrs"))); + } + if (const auto* attrs = n->attrs.as()) { + std::vector> sorted; + for (const auto& kv : attrs->dict) { + sorted.push_back(kv); + } + std::sort(sorted.begin(), sorted.end()); + for (const auto& kv : sorted) { + kwargs_keys.push_back(kv.first); + kwargs_values.push_back( + d->AsDoc(kv.second, n_p->Attr("attrs")->Attr(kv.first))); + } + } else { + AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); + const_cast(n->attrs.get())->VisitAttrs(&printer); + } + } + // Step 4. Print type_args + if (n->sinfo_args.size() > 0) { + ObjectPath sinfo_args_p = n_p->Attr("sinfo_args"); + Array sinfo_args; + for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { + sinfo_args.push_back( + d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayIndex(i))); + } + kwargs_keys.push_back("sinfo_args"); + kwargs_values.push_back(TupleDoc(sinfo_args)); + } + return prefix->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc new file mode 100644 index 000000000000..66d7d187d0c8 --- /dev/null +++ b/src/script/printer/relax/expr.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
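`PrintCallTIRDPSPacked` above renders `relax.call_tir` / `relax.call_dps_packed` with the destination struct info as an `out_sinfo=` keyword (a list when the output is a tuple) and any trailing TIR variables as `tir_vars=`; other calls fall back to `R.call_packed` for `ExternFunc` callees or to the op's `R.`-prefixed name. A hedged sketch of the rendered form (the PrimFunc body is elided and all names are illustrative):

```python
# from tvm.script import ir as I
# from tvm.script import relax as R
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.evaluate(0)  # body elided; only the signature matters here

    @R.function
    def main(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        gv = R.call_tir(Module.add_one, (x,), out_sinfo=R.Tensor((8,), "float32"))
        return gv
```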
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): float numbers + return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::StringImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataTypeImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Tuple n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): revisit tuple printing + if (n->fields.empty()) { + return Relax(d, "tuple")->Call({}); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return TupleDoc(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { + ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); + return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array values_doc; + ObjectPath values_p = n_p->Attr("values"); + for (int i = 0, l = n->values.size(); i < l; ++i) { + values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); + } + return Relax(d, "shape")->Call({ListDoc(values_doc)}); + }); + +Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { + DataType dtype = n.DataType(); + const void* data = n->data; + if (n->ndim != 0 || n->device.device_type != kDLCPU) { + return NullOpt; + } + if (dtype == DataType::Int(32)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Int(64)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(32)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(64)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Bool()) { + return LiteralDoc::Boolean(*reinterpret_cast(data), p); + } else { + return NullOpt; + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Constant n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { + return Relax(d, "const") + ->Call({ + s.value(), + LiteralDoc::DataType(n->data.DataType(), n_p->Attr("data")->Attr("dtype")), + }); + } + return d->AddMetadata(n); + }); + +Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) { + if 
(!d->IsVarDefined(n)) { + ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); + Frame f = d->frames.back(); + ExprDoc var = DefineVar(n, f, d); + f->stmts.push_back(AssignDoc(var, NullOpt, ann)); + } + return d->GetVarDoc(n).value(); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); + +TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc new file mode 100644 index 000000000000..fd7bdddfcaf5 --- /dev/null +++ b/src/script/printer/relax/function.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(RelaxFrameNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { + std::unordered_set func_vars; + With f(d); + IdDoc func_name = d->Define(n, f(), FindFunctionName(d, n).value_or("main")); + (*f)->AddDispatchToken(d, "relax"); + (*f)->is_func = true; + (*f)->func_vars = &func_vars; + // Step 1. Print the return type + Optional ret_type = NullOpt; + if (const auto& func_sinfo = relax::MatchStructInfo(n)) { + ret_type = d->AsDoc(func_sinfo.value()->ret, // + n_p->Attr("struct_info_")->Attr("ret")); + } + // Step 2. Print params + Array params; + { + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, l = n->params.size(); i < l; ++i) { + params.push_back(AssignDoc( + /*lhs=*/DefineVar(n->params[i], *f, d), + /*rhs=*/NullOpt, StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, NullOpt))); + } + } + // Step 3. Clean up func variables + (*f)->func_vars = nullptr; + // Step 4. Print attributes + if (n->attrs.defined() && !n->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } + // Step 5. 
Print body + Array body = + PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); + (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); + return HeaderWrapper( + d, FunctionDoc(func_name, params, {Relax(d, "function")}, ret_type, (*f)->stmts)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): print more information out of extern function. + return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p)); + }); + +TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc new file mode 100644 index 000000000000..1ac0b5ba14df --- /dev/null +++ b/src/script/printer/relax/region.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret) { + With f(d); + const Array& blocks = n->blocks; + ObjectPath blocks_p = n_p->Attr("blocks"); + Array* stmts = &(*f)->stmts; + for (int i = 0, l = blocks.size(); i < l; ++i) { + Doc block = d->AsDoc(blocks[i], blocks_p->ArrayIndex(i)); + if (const auto* stmt_block = block.as()) { + stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } else if (const auto* stmt = block.as()) { + stmts->push_back(GetRef(stmt)); + } else { + LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); + } + } + ExprDoc ret = d->AsDoc(n->body, n_p->Attr("body")); + if (use_ret) { + stmts->push_back(ReturnDoc(ret)); + } else { + stmts->push_back(ExprStmtDoc(ret)); + } + return *stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::SeqExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); + }); + +Array PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& n_p, + const IRDocsifier& d, Array* non_dataflow_vars) { + const Array& bindings = n->bindings; + ObjectPath bindings_p = n_p->Attr("bindings"); + Array stmts; + for (int i = 0, l = bindings.size(); i < l; ++i) { + const relax::Binding& binding = bindings[i]; + ObjectPath binding_p = bindings_p->ArrayIndex(i); + ICHECK(binding->var.defined()); + Doc binding_doc = d->AsDoc(binding, binding_p); + if (const auto* stmt = binding_doc.as()) { + stmts.push_back(GetRef(stmt)); + } else if (const auto* stmt_block = binding_doc.as()) { + stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } 
else { + LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey(); + } + if (non_dataflow_vars != nullptr && !binding->var->IsInstance()) { + non_dataflow_vars->push_back(d->AsDoc(binding->var, binding_p->Attr("var"))); + } + } + return stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::BindingBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataflowBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array non_dataflow_vars; + Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); + stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); + return ScopeDoc(NullOpt, Relax(d, "dataflow")->Call({}), stmts); + }); + +TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc new file mode 100644 index 000000000000..c541619ec887 --- /dev/null +++ b/src/script/printer/relax/struct_info.cc @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) { + ExprDoc expr_doc = d->AsDoc(e, e_p); + // Step 1. Find if `func_vars` are being collected + const RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->func_vars) { + f = relax_frame; + break; + } + } + } + // Step 2. Figure out if the PrimExpr contains at least a func var + bool func_var_mode = false; + if (f != nullptr) { + tir::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + if (f->func_vars->count(var)) { + func_var_mode = true; + } + } + }); + } + // Step 3. 
Stringify the PrimExpr if func var exists + if (func_var_mode) { + return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p); + } + return expr_doc; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->values.defined()) { + Array shape = n->values.value(); + ObjectPath shape_p = n_p->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayIndex(i), d)); + } + return Relax(d, "Shape")->Call({ListDoc(shape_docs)}); + } + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array args; + Array kwargs_keys; + Array kwargs_values; + if (n->shape.defined()) { + // Need to dig into ShapeExpr to preserve the `R.shape` prefix + if (const auto* shape = n->shape.value().as()) { + auto shape_expr = GetRef(shape); + ObjectPath shape_p = n_p->Attr("shape")->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { + shape_docs.push_back( + PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d)); + } + args.push_back(TupleDoc(shape_docs)); + } else { + args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + } + } + if (!n->IsUnknownDtype()) { + kwargs_keys.push_back("dtype"); + kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + } + if (!n->shape.defined() && !n->IsUnknownNdim()) { + kwargs_keys.push_back("ndim"); + kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))); + } + if (args.empty() && kwargs_keys.empty()) { + return Relax(d, "Tensor"); + } + return Relax(d, "Tensor")->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->IsOpaque()) { + return Relax(d, "Callable"); + } + // TODO(@junrushao): track symbolic shape relation + Array params_doc; + Array params = n->params.value(); + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, n_params = params.size(); i < n_params; ++i) { + params_doc.push_back(d->AsDoc(params[i], params_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(params_doc), // + d->AsDoc(n->ret, n_p->Attr("ret"))}); + }); + +TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc new file mode 100644 index 
000000000000..2a098644e07d --- /dev/null +++ b/src/script/printer/relax/tir.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +/*! \brief Find the outmost Relax function frame. If not exist, the outmost Relax frame. */ +RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { + RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->is_func) { + f = const_cast(relax_frame); + break; + } else if (f == nullptr) { + f = const_cast(relax_frame); + } + } + } + return f; +} + +Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { + ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses " + "scalar integer TIR variables, but gets: " + << n; + if (!d->IsVarDefined(n)) { + RelaxFrameNode* f = GetRelaxFrame(d); + // There should be at least one Relax frame + if (f == nullptr) { + LOG(FATAL) << "IndexError: No relax environment is found when printing a TIR var under " + "relax's dispatch token"; + } + // If the Relax function frame is collecting func vars + if (f->func_vars) { + ICHECK(f->is_func); + f->func_vars->insert(n.get()); + } + IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? 
"v" : n->name_hint); + var->source_paths.push_back(n_p); + f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt)); + } + if (Optional doc = d->GetVarDoc(n)) { + return doc.value(); + } + LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // + // TODO(@junrushao): support non-int64 cases + return LiteralDoc::Int(n->value, n_p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + if (Optional doc = d->GetVarDoc(n)) { + return doc.value(); + } else { + IdDoc ret(n->name_hint); + ret->source_paths.push_back(n_p); + return ret; + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { // + Optional doc = d->GetVarDoc(mod); + ICHECK(doc) << "Unable to print IRModule before definition in Relax."; + if (d->cfg->module_alias.empty()) { + // Use Module Name directly + return doc.value(); + } + RelaxFrameNode* f = GetRelaxFrame(d); + ICHECK(f != nullptr && f->is_func) + << "IndexError: No relax environment is found when printing a module alias var " + "under relax's dispatch token"; + if (!f->module_alias_printed) { + // If the module_alias is not defined before, define it. + f->stmts.push_back(AssignDoc(IdDoc(d->cfg->module_alias), doc.value(), NullOpt)); + f->module_alias_printed = true; + } + return IdDoc(d->cfg->module_alias); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc new file mode 100644 index 000000000000..d13d90b1d5ed --- /dev/null +++ b/src/script/printer/relax/type.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::DynTensorType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Tensor") + ->Call({}, {"ndim", "dtype"}, + {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), + LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PackedFuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::TupleType n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "relax", [](tvm::FuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array arg_types_doc; + Array arg_types = n->arg_types; + ObjectPath arg_types_p = n_p->Attr("arg_types"); + for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { + arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(arg_types_doc), // + d->AsDoc(n->ret_type, n_p->Attr("ret_type"))}); + }); + +TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DynTensorTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); +TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h new file mode 100644 index 000000000000..97acb79c3d24 --- /dev/null +++ b/src/script/printer/relax/utils.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
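`PrintTIRVar` and `PrintShapeVar` in the TIR-dispatch file above cooperate on symbolic shapes: dimensions that introduce function-level variables are stringified inside signature annotations, while TIR variables needed in the body are declared with a `T.int64()` assignment. A hedged sketch of the printed output for a function with a symbolic dimension `n`:

```python
# from tvm.script import relax as R
# from tvm.script import tir as T

@R.function
def grow(x: R.Tensor(("n",), dtype="float32")) -> R.Tensor(("n",), dtype="float32"):
    n = T.int64()
    y: R.Tensor((n,), dtype="float32") = R.add(x, x)
    return y
```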
+ */ +#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class RelaxFrameNode : public FrameNode { + public: + bool is_func = false; + bool module_alias_printed = false; + std::unordered_set* func_vars = nullptr; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("is_global_func", &is_func); + // `func_var_to_define` is not visited + } + + static constexpr const char* _type_key = "script.printer.RelaxFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode); +}; + +class RelaxFrame : public Frame { + public: + explicit RelaxFrame(const IRDocsifier& d) { + ObjectPtr n = make_object(); + n->stmts.clear(); + n->d = d.get(); + n->is_func = false; + n->func_vars = nullptr; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, Frame, RelaxFrameNode); +}; + +/*! \brief Redirected method for the ReprPrinter */ +inline std::string ReprPrintRelax(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); + With f(d); + (*f)->AddDispatchToken(d, "relax"); + return Docsify(obj, d, *f, cfg); +} + +inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsifier& d) { + return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); +} + +inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p, + const IRDocsifier& d, const Optional& rhs) { + if (!v->struct_info_.defined()) { + return NullOpt; + } + if (const auto* call = rhs.as()) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { + return NullOpt; + } + } + return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); +} + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index f40d7818d7e1..4c24d710bb05 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -68,6 +68,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With f(d, func); (*f)->AddDispatchToken(d, "tir"); + IdDoc func_name = d->Define(func, f(), FindFunctionName(d, func).value_or("main")); d->SetCommonPrefix(func, [](const ObjectRef& obj) { return obj->IsInstance() || obj->IsInstance(); }); @@ -167,7 +168,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } return HeaderWrapper(d, FunctionDoc( - /*name=*/IdDoc(FindFunctionName(d, func).value_or("main")), + /*name=*/func_name, /*args=*/args, /*decorators=*/{TIR(d, "prim_func")}, /*return_type=*/ret_type, diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index ec0f0eaf72b0..2a2f46908271 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -111,6 +111,12 @@ inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { return IdDoc(d->cfg->tir_prefix)->Attr(attr); } +/*! 
\brief Creates the TIR common prefix, which is by default `T` */ +inline ExprDoc Relax(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("relax"); + return IdDoc(d->cfg->relax_prefix)->Attr(attr); +} + inline std::string DType2Str(const runtime::DataType& dtype) { return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype); } @@ -125,7 +131,9 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { if (d->ir_usage.count("tir")) { stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix)); } - + if (d->ir_usage.count("relax")) { + stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix)); + } stmts.push_back(CommentDoc("")); stmts.push_back(Downcast(doc)); return StmtBlockDoc(stmts); diff --git a/src/support/array.h b/src/support/array.h index 218150f9dba0..0ca57a2410c5 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { @@ -81,11 +82,35 @@ inline std::vector AsVector(const Array& vec); * \brief Convert a std::vector to tvm::runtime::Array * \tparam TSrc The type of elements in the source vector * \tparam TDst The type of elements in the result Array - * \return The result vector + * \return The result Array */ template inline Array AsArray(const std::vector& vec); +/*! + * \brief Convert a tvm::runtime::Array to std::list + * \tparam T The type of elements in the source array + * \return The result list + */ +template +inline std::list AsList(const Array& array) { + std::list list; + for (const auto& v : array) list.push_back(v); + return list; +} + +/*! + * \brief Convert a std::list to tvm::runtime::Array + * \tparam T The type of elements in the source list + * \return The result list + */ +template +inline Array AsArray(const std::list& list) { + Array array; + for (const auto& v : list) array.push_back(v); + return array; +} + /*! * \brief Get the shape tuple as array * \param shape The shape tuple diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 398e24d2510e..ab9a2ff594b2 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -119,7 +119,7 @@ TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", DispatchPureExtern); PrimExpr DispatchFastErf(const PrimExpr& e) { - LOG(WARNING) << "fast_erf will be used instead of erf"; + DLOG(WARNING) << "fast_erf will be used instead of erf"; const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21d2c6ebe0a5..10aa2688a846 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -905,8 +905,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); } - - nargs -= 1; + // NOTE: This is a bugfix to a previous coupled convention(in lower_tvm_builtin) + // The begin, end should correspond to the right location in cpacked excluding resource handle. + // TODO(tqchen): upstream the fix. 
+ // nargs -= 1; call_args.insert(call_args.end(), { builder_->CreateBitCast(arg_value, t_void_p_), arg_tcode.addr, diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index fd770007e243..32b32063940b 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -23,11 +23,13 @@ #include "codegen_webgpu.h" #include +#include #include #include #include #include +#include #include #include @@ -39,6 +41,63 @@ namespace tvm { namespace codegen { +// WebGPU Info +struct WebGPUWorkGroupInfo { + int workgroup_size[3] = {1, 1, 1}; + // whether we have ref to block index z is used. + bool has_block_index_z{false}; + // set of handles that have write access + std::unordered_set write_access_set; +}; + +class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { + public: + static WebGPUWorkGroupInfo Collect(const Stmt& stmt) { + WebGPUWorkgroupInfoCollector collector; + collector(stmt); + return collector.info_; + } + + private: + void VisitExpr_(const VarNode* op) final { + StmtExprVisitor::VisitExpr_(op); + Var buffer_var = GetRef(op); + if (buffer_var.dtype().is_handle()) { + info_.write_access_set.insert(buffer_var); + } + } + + void VisitStmt_(const BufferStoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + info_.write_access_set.insert(op->buffer->data); + } + + void VisitStmt_(const AttrStmtNode* op) final { + // record workgroup size + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag.length() != 0) { + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); + if (ts.rank == 1) { + ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; + ICHECK_LT(ts.dim_index, 3); + auto* sizeptr = op->value.as(); + ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " + << " get " << op->value; + info_.workgroup_size[ts.dim_index] = static_cast(sizeptr->value); + } else if (ts.rank == 0) { + if (ts.dim_index == 2) { + info_.has_block_index_z = true; + } + } + } + } + // normal operation + StmtExprVisitor::VisitStmt_(op); + } + WebGPUWorkGroupInfo info_; +}; + std::string CodeGenWebGPU::Finish() { return decl_stream.str() + this->fwd_decl_stream.str() + stream.str(); } @@ -51,12 +110,11 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { alloc_storage_scope_[arg.get()] = "global"; } } - std::fill(workgroup_size_, workgroup_size_ + 3, 1); } CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} -void CodeGenWebGPU::AddFunction(const PrimFunc& f) { +runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) { // clear previous generated state. this->InitFuncState(f); // skip the first underscore, so SSA variable starts from @@ -64,6 +122,7 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { // Setup the thread group info. ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); // add to alloc buffer type. 
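  // (Editorial note) WebGPUWorkgroupInfoCollector, defined above, walks the body once
  // before any code is emitted, roughly as
  //   WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);
  // It records the constant workgroup size from thread_extent attributes, whether
  // blockIdx.z is referenced, and which buffer vars are written, so that the storage
  // bindings can be declared read vs read_write and the @workgroup_size annotation can be
  // emitted together with the function header further down.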
auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); @@ -71,15 +130,24 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; decl_stream << "//----------------------------------------\n" - << "// function: " << global_symbol.value() << "\n" + << "// Function: " << global_symbol.value() << "\n" << "//----------------------------------------\n"; + runtime::FunctionInfo func_info; + func_info.name = global_symbol.value(); + + WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body); std::vector pod_args; int num_buffer = 0; + + // add param_access modes info to launch params + std::ostringstream os_param_access; + os_param_access << "paramWriteAccess:["; // setup buffer argumemts for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { + func_info.arg_types.push_back(t); auto* ptr = arg->type_annotation.as(); ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " "PointerType, " @@ -95,8 +163,20 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } std::string vid = AllocVarID(arg.get()); + std::string access_mode; + if (num_buffer != 0) { + os_param_access << ","; + } + if (skip_readonly_decl || info.write_access_set.count(arg)) { + access_mode = "read_write"; + os_param_access << "1"; + } else { + access_mode = "read"; + os_param_access << "0"; + } + // add extra access mode info to launch params this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " - << "var " << vid << " : array<"; + << "var " << vid << " : array<"; this->PrintType(value_storage_type, this->decl_stream); this->decl_stream << ">;\n"; } else { @@ -104,15 +184,33 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { } } + // setup thread tags and param access in launch param tags; + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + auto thread_axis = opt.value(); + for (size_t i = 0; i < thread_axis.size(); ++i) { + func_info.launch_param_tags.push_back(thread_axis[i]->thread_tag); + } + } + os_param_access << "]"; + func_info.launch_param_tags.push_back(os_param_access.str()); + if (pod_args.size() != 0) { // setup POD arguments // TODO(tvm-team): store as a uniform, readonly buffer. LOG(FATAL) << "Do not support pod arguments for now"; } + + ICHECK(!info.has_block_index_z) + << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x"; + // anotate workgroup + this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", " + << info.workgroup_size[1] << ", " << info.workgroup_size[2] << ")\n"; + // add to alloc buffer type. // Function header. - this->stream << "fn main(\n" + this->stream << "fn " << func_info.name << "(\n" << " @builtin(workgroup_id) blockIdx : vec3,\n" + << " @builtin(num_workgroups) gridDim : vec3,\n" << " @builtin(local_invocation_id) threadIdx : vec3\n" << ") {\n"; // the function scope. 
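// (Editorial note) Under the changes above, the emitted WGSL prologue for a kernel looks
// roughly like the following sketch; the buffer names, access modes, and workgroup size are
// hypothetical and only illustrate the shape of the output:
//   @group(0) @binding(0) var<storage, read_write> A : array<f32>;
//   @group(0) @binding(1) var<storage, read> B : array<f32>;
//   @compute @workgroup_size(64, 1, 1)
//   fn my_kernel(
//     @builtin(workgroup_id) blockIdx : vec3<u32>,
//     @builtin(num_workgroups) gridDim : vec3<u32>,
//     @builtin(local_invocation_id) threadIdx : vec3<u32>
//   ) { ... }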
@@ -121,39 +219,26 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n\n"; - // anotate workgroup - this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " - << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; -} - -void CodeGenWebGPU::VisitStmt_(const AttrStmtNode* op) { - // record workgroup size - if (op->attr_key == tir::attr::thread_extent) { - IterVar iv = Downcast(op->node); - if (iv->thread_tag.length() != 0) { - runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); - if (ts.rank == 1) { - ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; - ICHECK_LT(ts.dim_index, 3); - auto* sizeptr = op->value.as(); - ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " - << " get " << op->value; - workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); - } - } - } - // normal operation - CodeGenC::VisitStmt_(op); + return func_info; } void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { ICHECK(!var_idmap_.count(iv->var.get())); std::ostringstream os; PrintType(iv->var.dtype(), os); - os << "(" << iv->thread_tag << ")"; - std::string tidx = os.str(); - this->MarkConst(tidx); - var_idmap_[iv->var.get()] = tidx; + if (iv->thread_tag == "blockIdx.x") { + // WebGPU have restriction to limit the maximum size of blockId.x to be 65535 + // We allow runtime to spread the load out to blockIdx.z so it can be a large number. + os << "(blockIdx.z * gridDim.x + blockIdx.x)"; + std::string tidx = os.str(); + std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype()); + var_idmap_[iv->var.get()] = aggregated_bidx; + } else { + os << "(" << iv->thread_tag << ")"; + std::string tidx = os.str(); + this->MarkConst(tidx); + var_idmap_[iv->var.get()] = tidx; + } } void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) @@ -179,8 +264,10 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; os << "f" << t.bits(); } else if (t.is_uint()) { + ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64"; os << "u" << t.bits(); } else if (t.is_int()) { + ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64"; os << "i" << t.bits(); } else { LOG(FATAL) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; @@ -221,6 +308,10 @@ void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << ')'; } +PrimExpr CodeGenWebGPU::EnforceU32(PrimExpr value) { + return cast(DataType::UInt(32, value.dtype().lanes()), value); +} + void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate bitcast(ARG) @@ -229,6 +320,20 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN os << ">("; this->PrintExpr(op->args[0], os); os << ")"; + } else if (op->op.same_as(builtin::shift_right())) { + os << '('; + this->PrintExpr(op->args[0], os); + os << ">>"; + // WebGPU requires shift bits to be u32. + this->PrintExpr(EnforceU32(op->args[1]), os); + os << ')'; + } else if (op->op.same_as(builtin::shift_left())) { + os << '('; + this->PrintExpr(op->args[0], os); + os << "<<"; + // WebGPU requires shift bits to be u32. 
+ this->PrintExpr(EnforceU32(op->args[1]), os); + os << ')'; } else if (op->op.same_as(builtin::if_then_else())) { // conditional that skips eval if cond evals to false std::string result = name_supply_->FreshName("condval"); @@ -241,14 +346,16 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN this->stream << "if (" << cond << ") {\n"; { int then_scope = this->BeginScope(); + std::string true_val = PrintExpr(op->args[1]); this->PrintIndent(); - this->stream << result << " = " << PrintExpr(op->args[1]) << ";\n} else {\n"; + this->stream << result << " = " << true_val << ";\n} else {\n"; this->EndScope(then_scope); } { int else_scope = this->BeginScope(); + std::string false_val = PrintExpr(op->args[2]); this->PrintIndent(); - this->stream << result << " = " << PrintExpr(op->args[2]) << ";\n}\n"; + this->stream << result << " = " << false_val << ";\n}\n"; this->EndScope(else_scope); } os << result; @@ -444,10 +551,13 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { PrintType(op->dtype, this->decl_stream); this->decl_stream << ", " << constant_size << ">;\n"; } else if (storage_scope.rank == runtime::StorageRank::kLocal) { - this->PrintIndent(); - this->stream << "var " << vid << " : array<"; - PrintType(op->dtype, this->stream); - this->stream << ", " << constant_size << ">;\n"; + this->decl_stream << "var " << vid << " : array<"; + PrintType(op->dtype, this->decl_stream); + this->decl_stream << ", " << constant_size << ">;\n"; + // this->PrintIndent(); + // this->stream << "var " << vid << " : array<"; + // PrintType(op->dtype, this->stream); + // this->stream << ", " << constant_size << ">;\n"; } else { LOG(FATAL) << "WebGPU: Do not support storage scope: " << storage_scope.to_string(); } @@ -456,11 +566,10 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { void CodeGenWebGPU::VisitStmt_(const ForNode* op) { std::string extent = PrintExpr(op->extent); - PrintIndent(); std::string vid = AllocVarID(op->loop_var.get()); ICHECK(is_zero(op->min)); - stream << "for (var "; - stream << vid << " : "; + PrintIndent(); + stream << "for (var " << vid << " : "; PrintType(op->loop_var.dtype(), stream); stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; int for_scope = BeginScope(); @@ -507,11 +616,17 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { } std::string GetSource(const std::string& format) final { - std::ostringstream os; - for (auto kv : smap_) { - os << kv.second; + if (format == "func_info") { + std::ostringstream stream; + dmlc::JSONWriter(&stream).Write(fmap_); + return stream.str(); + } else { + std::ostringstream os; + for (auto kv : smap_) { + os << kv.second; + } + return os.str(); } - return os.str(); } private: @@ -527,8 +642,13 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { runtime::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; - + bool skip_readonly_decl = false; std::unordered_map smap; + std::unordered_map fmap; + + // narrow all i64 to i32 + mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod)); + for (auto kv : mod->functions) { CodeGenWebGPU cg(target); ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; @@ -541,11 +661,12 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); cg.Init(output_ssa); - 
cg.AddFunction(f); + fmap[f_name] = cg.AddFunction(f, skip_readonly_decl); std::string code = cg.Finish(); smap[f_name] = code; } - auto n = make_object(smap, ExtractFuncInfo(mod)); + + auto n = make_object(smap, fmap); return runtime::Module(n); } diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index 57f226ba8ad6..ff99f4608a39 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -48,7 +48,7 @@ class CodeGenWebGPU final : public CodeGenC { explicit CodeGenWebGPU(Target target); // overrides std::string Finish() final; - void AddFunction(const PrimFunc& f); // NOLINT(*) + runtime::FunctionInfo AddFunction(const PrimFunc& f, bool skip_readonly_decl); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) @@ -71,15 +71,14 @@ class CodeGenWebGPU final : public CodeGenC { void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const AllocateNode* op) final; - void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const AssertStmtNode* op) final; void VisitStmt_(const AllocateConstNode* op) final; private: /*! - * \brief Records the workgroup size of the kernel. + * \brief Enforce value to be U32. */ - uint32_t workgroup_size_[3]; + static PrimExpr EnforceU32(PrimExpr value); /*! * \brief Storage type of bool values. */ diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ac304b92b6d7..ffef425c0e41 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -27,6 +27,8 @@ #include #include +#include "../intrin_rule.h" + namespace tvm { namespace codegen { namespace spirv { @@ -100,6 +102,9 @@ TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.erf").set_attr("vulkan.FLowerIntrinsic", + codegen::intrin ::DispatchFastErf); } // namespace intrin namespace legalize { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index cc52cf618dc1..133b9923fb71 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -581,5 +581,85 @@ TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* *ret = CreatePrimFunc(arg_list, index_dtype_override); }); +// Relax version impl +PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, + const Array& root_stmts, CreateFuncInfo* info, + const Optional> tir_var_list) { + Array parameters; + Map buffer_map; + for (const te::Tensor& tensor : arg_list) { + Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); + parameters.push_back(arg); + auto it = info->tensor2buffers.find(tensor); + ICHECK(it != info->tensor2buffers.end()); + buffer_map.Set(arg, it->second); + } + + // add additional arguments for tir vars that are left unbound by match buffer + if (tir_var_list) { + for (const Var& v : tir_var_list.value()) { + parameters.push_back(v); + } + } + + PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), + /*body=*/SeqStmt::Flatten(root_stmts), + /*ret_type=*/VoidType(), + /*buffer_map=*/std::move(buffer_map)), + {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); + + const auto* complete = runtime::Registry::Get("script.Complete"); + ICHECK(complete); + func = 
(*complete)(std::move(func), info->root_alloc); + return func; +} + +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override) { + // Infomations used in CreatePrimFunc and its sub-functions. + CreateFuncInfo info(arg_list); + // Root body stmts. + Array root_stmts; + // Analyzer + arith::Analyzer analyzer; + + // Step 1. Create ordered array of operations and validate they are supported. + Array order = CollectOrderedOps(arg_list); + + // Step 2. Initialize buffer binds map + InitializeBufferBinds(order, &info); + + // Step 3. Rewrite compute stages into blocks. + for (const te::Operation& op : order) { + RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + } + auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, tir_var_list); + func = tir::BindParams(func, constants); + if (index_dtype_override.has_value()) { + func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); + } + auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); + return result; +} + +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override) { + return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override); +} + +TVM_REGISTER_GLOBAL("te.CreateRelaxPrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { + Array arg_list = args[0]; + Optional> tir_var_list = args[1]; + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[2].type_code() != kTVMNullptr) { + index_dtype_override = args[2].operator DataType(); + } + *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override); +}); + } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 4246347a16f3..946f024849bf 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -42,6 +42,23 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, std::optional index_dtype_override = std::nullopt); +// Relax version +// TODO(relax-team) combine with the relay version +/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override); + +/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the + * constants array is N, the last N tensors in arg_list will be treated as constant tensors. + * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants + * will be embedded in the body as AllocateConstNode. 
+ */ +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override = std::nullopt); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 4c59a1767372..781a0ecd7c3d 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -115,8 +115,8 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return PrimFuncPass(pass_func, pass_info); } diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index c85590428450..a800b9d77c29 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -355,6 +355,18 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic) TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(anylist_getitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc new file mode 100644 index 000000000000..9ee6c67ec96b --- /dev/null +++ b/src/tir/op/runtime.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/runtime.cc + * \brief TIR ops for runtime functions. 
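 * (Editorial note) These two externally-visible runtime calls back the anylist_* builtins
 * registered in builtin.cc just above: lower_tvm_builtin.cc forwards an anylist_getitem
 * argument with TVMBackendAnyListSetPackedArg and moves a packed call's return value into
 * the list with TVMBackendAnyListMoveFromPackedReturn.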
+ */ +#include +#include + +namespace tvm { +namespace tir { + +TVM_REGISTER_OP("tir.TVMBackendAnyListSetPackedArg") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.TVMBackendAnyListMoveFromPackedReturn") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc new file mode 100644 index 000000000000..8666d7eb4712 --- /dev/null +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../meta_schedule/utils.h" + +namespace tvm { +namespace tir { +namespace transform { +/*! + * \brief A helper function to do default thread binding for a block. + * \param sch The schedule to work on. + * \param block The block to be scheduled. + * \param max_thread_per_block The maximum number of threads per block. + * \param max_threadblocks The maximum number of threadblocks. 
+ */ +void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread_per_block, + int64_t max_threadblocks = 256) { + // fetch the loops + Array loops = sch->GetLoops(block); + for (const tir::LoopRV& loop : loops) { + // skip block if already scheduled + if (sch->Get(loop)->thread_binding.defined()) { + return; + } + } + Array iters = sch->Get(block)->iter_vars; + ICHECK_EQ(loops.size(), iters.size()); + Array data_parallel_loops; + // only fuse data parallel loops + for (size_t i = 0; i < loops.size(); ++i) { + if (iters[i]->iter_type == tir::IterVarType::kDataPar) { + data_parallel_loops.push_back(loops[i]); + } + } + // skip if no data parallel loops + if (data_parallel_loops.size() == 0) { + return; + } + // fuse all data parallel loops + tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false); + int64_t product = std::numeric_limits::max(); + if (sch->Get(fused)->extent->IsInstance()) { + product = sch->Get(fused)->extent.as()->value; + } + // schedule the fused loop + if (product > max_thread_per_block * max_threadblocks) { + Array splits = + sch->Split(fused, + /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)}); + sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); + sch->Bind(splits[1], "blockIdx.x"); + sch->Bind(splits[2], "threadIdx.x"); + } else { + Array splits = + sch->Split(fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))}); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + } +} + +Pass DefaultGPUSchedule() { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + // get the target from context. + tvm::Target target = tvm::Target::Current(); + ICHECK(target.defined()) << "Target is not set in current context"; + // skip non-cuda targets. + if (target->kind->name != "cuda") { + return m; + } + // get the max thread per block from target. + Optional opt_max_thread_per_block = target->GetAttr("max_num_threads"); + ICHECK(opt_max_thread_per_block.defined()) + << "max_num_threads is not set for target " << target; + int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); + tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0, + tir::ScheduleErrorRenderLevel::kDetail); + for (const auto& [gv, func] : m->functions) { + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kIsScheduled)) { + sch->WorkOn(gv->name_hint); + Array blocks = meta_schedule::BlockCollector::Collect(sch); + for (const tir::BlockRV& block : blocks) { + ThreadBind(sch, block, max_thread_per_block); + } + } + } + return sch->mod(); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"DefaultGPUSchedule", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc new file mode 100644 index 000000000000..70dc554e120a --- /dev/null +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file force_narrow_index_to_i32.cc + * \brief Force narrow down indexing expressions and integer buffers to int32 dtype. + * \note This pass is not used in default cases. + */ + +#include +#include + +namespace tvm { +namespace tir { + +class Int32DTypeNarrower : public IndexDataTypeNormalizer { + public: + static PrimFunc RewriteDataType(PrimFunc func) { + // Check if the integer parameter buffers have dtype other than int32. + for (auto it : func->buffer_map) { + if (it.second->dtype.is_int() && it.second->dtype.bits() != 32) { + LOG(FATAL) << "The buffer " << it.second << " in the function buffer map has dtype " + << it.second->dtype << ". The function is " << func; + } + } + + Int32DTypeNarrower narrower(func); + return narrower.Rewrite(func); + } + + private: + explicit Int32DTypeNarrower(PrimFunc func) + : IndexDataTypeNormalizer(DataType::Int(32)), func_(std::move(func)) {} + + Stmt VisitStmt_(const BlockNode* block) final { + Block block_ = Downcast(IndexDataTypeNormalizer::VisitStmt_(block)); + // Check if the allocated integer buffers have dtype other than int32. + for (const Buffer& buf : block_->alloc_buffers) { + if (buf->dtype.is_int() && buf->dtype.bits() != 32) { + LOG(FATAL) << "The buffer " << buf << " allocated in the function has dtype " << buf->dtype + << ". 
The function is " << func_; + } + } + return block_; + } + + PrimFunc func_; +}; + +PrimFunc ForceNarrowIndexToInt32(PrimFunc func) { + return Int32DTypeNarrower::RewriteDataType(func); +} + +namespace transform { + +Pass ForceNarrowIndexToInt32() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return ForceNarrowIndexToInt32(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") + .set_body_typed(ForceNarrowIndexToInt32); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index ea418635bc2a..1019d89ba55b 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -302,13 +302,21 @@ class BuiltinLower : public StmtExprMutator { return Stmt(n); } } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { - return MakeCallPacked(op, /* use_string_lookup */ true); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), + /* use_string_lookup */ true); } else if (op->op.same_as(builtin::tvm_call_cpacked())) { - return MakeCallPacked(op, /* use_string_lookup */ false); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(), + /* use_string_lookup */ false); } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { - return MakeCallTracePacked(op); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(), + /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true); + } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered(), false); } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); } else if (op->op.same_as(builtin::tvm_stack_make_array())) { @@ -447,8 +455,68 @@ class BuiltinLower : public StmtExprMutator { cast(DataType::Int(32), device_type_.value()))); return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } - // call packed. 
- PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) { + + void SetPackedArg(PrimExpr arg, const Var& value_stack, const Buffer& tcode_stack, + size_t stack_offset, std::vector* prep_seq) { + auto* call_pattern = arg.as(); + if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { + // call runtime function to set anylist + prep_seq->emplace_back( + Evaluate(Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), + {call_pattern->args[0], call_pattern->args[1], value_stack, + tcode_stack->data, ConstInt32(stack_offset)}))); + } else { + DataType api_type = APIType(arg.dtype()); + if (arg.dtype() != api_type) { + arg = Cast(api_type, arg); + } + prep_seq->emplace_back( + TVMStructSet(value_stack, stack_offset, builtin::kTVMValueContent, arg)); + int arg_tcode = api_type.code(); + if (api_type.is_handle() && arg.as()) { + arg_tcode = kTVMStr; + } else if (IsArrayHandle(arg)) { + arg_tcode = kTVMDLTensorHandle; + } + // opaque handle need to set the kind properly + if (arg_tcode == kTVMOpaqueHandle) { + prep_seq->emplace_back(IfThenElse( + Call(DataType::Bool(), builtin::isnullptr(), {arg}), + BufferStore(tcode_stack, ConstInt32(kTVMNullptr), {ConstInt32(stack_offset)}), + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)}))); + } else { + prep_seq->emplace_back( + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)})); + } + } + } + + PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op, + bool use_string_lookup) { + PrimExpr list_handle = op->args[0]; + PrimExpr list_index = op->args[1]; + + Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup); + PrimExpr value_stack = call->args[1]; + PrimExpr tcode_stack = call->args[2]; + // The stack offset of return value stack_end + PrimExpr ret_offset = call->args[4]; + auto& prep_seq = prep_seq_stack_.back(); + prep_seq.emplace_back(Evaluate(call)); + return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"), + {list_handle, list_index, value_stack, tcode_stack, ret_offset}); + } + /*! + * \brief Generic tool to make low-level + * packed_call(other_args..., func_name, packed_arg0, packed_arg1...) + * + * \param op The call + * \param name_offset The beginning of function name and call packed section. + * \param lowered_packed_op The target lowered op. + * \param use_string_lookup Whether to lookup function by string. + */ + Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op, + bool use_string_lookup) { auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -456,34 +524,24 @@ class BuiltinLower : public StmtExprMutator { size_t restore_array_stack = scope.run_sizes.array_stack; size_t arg_stack_begin = scope.run_sizes.arg_stack; - size_t arg_count = op->args.size(); + size_t args_begin = name_offset + 1; + size_t args_end = op->args.size(); // cpacked expects a resource_handle parameter if (!use_string_lookup) { - arg_count--; + --args_end; } + size_t num_args = args_end - args_begin; - scope.run_sizes.arg_stack += arg_count; + // The extra one slot is for return value. 
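    // (Editorial note) The value/tcode stacks therefore hold one slot per packed argument
    // plus a trailing slot reserved for the callee's return value; the lowered call built
    // below receives the index range [arg_stack_begin, arg_stack_begin + num_args) for its
    // arguments, and the return slot sits right after that range.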
+ scope.run_sizes.arg_stack += num_args + 1; // Specially handle the buffer packed intrinsic PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - for (size_t i = 1; i < arg_count; ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - if (api_type.is_handle() && arg.as()) { - arg_tcode = kTVMStr; - } - if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); + + for (size_t i = 0; i < num_args; ++i) { + this->SetPackedArg(op->args[args_begin + i], scope.stack_value, scope.stack_tcode, + arg_stack_begin + i, &prep_seq); } // Verify stack size matches earlier value. if (is_precheck_) { @@ -494,13 +552,12 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1)}; - + Array packed_args = {op->args[name_offset], scope.stack_value, + scope.stack_tcode->data, ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + num_args)}; // cpacked call resource_handle if (!use_string_lookup) { - PrimExpr last_arg = op->args[arg_count]; + PrimExpr last_arg = op->args[args_end]; const VarNode* var_node = last_arg.as(); if (var_node != nullptr) { tir::Var resource_handle = GetRef(var_node); @@ -509,57 +566,7 @@ class BuiltinLower : public StmtExprMutator { packed_args.push_back(last_arg); } } - - auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() - : builtin::tvm_call_cpacked_lowered(); - return Call(op->dtype, builtin_call, packed_args); - } - - PrimExpr MakeCallTracePacked(const CallNode* op) { - ICHECK(!alloca_scope_.empty()); - auto& scope = alloca_scope_.back(); - auto& prep_seq = prep_seq_stack_.back(); - - int64_t restore_shape_stack = scope.run_sizes.shape_stack; - size_t restore_array_stack = scope.run_sizes.array_stack; - size_t arg_stack_begin = scope.run_sizes.arg_stack; - scope.run_sizes.arg_stack += op->args.size(); - size_t args_size = op->args.size(); - ICHECK_GT(args_size, 0); - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - for (size_t i = 1; i < op->args.size(); ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); - } - // Verify stack size matches earlier value. 
- if (is_precheck_) { - scope.UpdateMax(); - } else { - scope.AssertMaxIsValid(); - } - scope.run_sizes.shape_stack = restore_shape_stack; - scope.run_sizes.array_stack = restore_array_stack; - // Update the top of the stack, so we can use more than one - // packed function's arguments with the one stack. - scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1]}; - return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); + return Call(op->dtype, lowered_packed_op, packed_args); } Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 3d1c6f9f7d5b..a9d692cc0752 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); }); +TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::collapse_sum(args[0], args[1]); +}); + } // namespace topi } // namespace tvm diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc new file mode 100644 index 000000000000..9ddae05e59e3 --- /dev/null +++ b/tests/cpp/nested_msg_test.cc @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::runtime; +using namespace tvm::relax; + +TEST(NestedMsg, Basic) { + // start with no annotation + relax::Var x("x", NullOpt), y("y", NullOpt); + + // constructor from array, T and nullopt. 
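  // (Editorial note) NestedMsg is a three-state value: null, a single leaf, or a nested
  // array of NestedMsg; the brace initializer below therefore produces the nested-array
  // case, with a null element in the middle slot.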
+ NestedMsg msg({x, NullOpt, x}); + + EXPECT_TRUE(msg.IsNested()); + EXPECT_FALSE(msg.IsLeaf()); + EXPECT_TRUE(msg != nullptr); + + EXPECT_ANY_THROW(msg.LeafValue()); + + auto arr = msg.NestedArray(); + EXPECT_TRUE(arr[0].same_as(x)); + EXPECT_TRUE(arr[1] == nullptr); + EXPECT_TRUE(arr[1].IsNull()); + + EXPECT_TRUE(arr[2].LeafValue().same_as(x)); + + auto a0 = arr[0]; + EXPECT_TRUE(a0.IsLeaf()); + + // assignment + // assign null + a0 = NullOpt; + EXPECT_TRUE(a0 == nullptr); + + // assign array + a0 = {x, {x, NullOpt, y}}; + EXPECT_TRUE(a0.IsNested()); + auto t0 = a0.NestedArray()[1]; + EXPECT_TRUE(t0.IsNested()); + EXPECT_TRUE(t0.NestedArray()[2].same_as(y)); + + // assign leaf + a0 = x; + + EXPECT_TRUE(a0.IsLeaf()); + EXPECT_TRUE(a0.same_as(x)); +} + +TEST(NestedMsg, ForEachLeaf) { + relax::Var x("x", NullOpt), y("y", NullOpt); + NestedMsg msg = {x, {x, y}, NullOpt, {x, {x, y}}}; + + int x_count = 0, y_count = 0; + + ForEachLeaf(msg, [&](const Expr& v) { + if (v.same_as(x)) ++x_count; + if (v.same_as(y)) ++y_count; + }); + EXPECT_EQ(x_count, 4); + EXPECT_EQ(y_count, 2); +} + +TEST(NestedMsg, Equal) { + relax::Var x("x", NullOpt), y("y", NullOpt); + relax::Var z("z", NullOpt); + + auto fequal = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); }; + + using M = NestedMsg; + + EXPECT_TRUE(Equal(M(NullOpt), M(NullOpt), fequal)); + + EXPECT_TRUE(Equal(M(x), M(x), fequal)); + + EXPECT_TRUE(Equal(M({x, y}), M({x, y}), fequal)); + + EXPECT_TRUE(Equal(M({x, NullOpt}), M({x, NullOpt}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}, {x, z}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + // type mismatch + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, x}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, {x, NullOpt}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M(x), M(NullOpt), fequal)); + + EXPECT_FALSE(Equal(M(NullOpt), M(x), fequal)); + + EXPECT_FALSE(Equal(M(x), M(Array({x})), fequal)); + + EXPECT_FALSE(Equal(M(Array({x})), M(x), fequal)); +} + +TEST(NestedMsg, MapAndDecompose) { + relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16))); + relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32))); + relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64))); + + BlockBuilder bb = BlockBuilder::Create(NullOpt); + relax::Expr t0 = bb->Normalize(Tuple({x, y})); + relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0})); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + auto output = MapToNestedMsg(t1, [&](Expr value) { + if (value.same_as(x)) return c0; + if (value.same_as(y)) return c1; + return c2; + }); + + NestedMsg expected = {{c0, c1}, c0, c2, {c0, c1}}; + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + auto output2 = + MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { + const auto* prim_sinfo = sinfo.as(); + if (prim_sinfo == nullptr) return NullOpt; + int bits = prim_sinfo->dtype.bits(); + if (bits == 16) return c0; + if (bits == 32) return c1; + if (bits == 64) return c2; + return NullOpt; + }); + + EXPECT_TRUE(Equal(output2, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + int x_count = 0, y_count = 0, z_count = 0; + + DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { + if 
(value.same_as(x)) { + EXPECT_TRUE(msg.same_as(c0)); + ++x_count; + } else if (value.same_as(y)) { + EXPECT_TRUE(msg.same_as(c1)); + ++y_count; + } else { + EXPECT_TRUE(msg.same_as(c2)); + ++z_count; + } + }); + EXPECT_EQ(x_count, 3); + EXPECT_EQ(y_count, 2); + EXPECT_EQ(z_count, 1); +} + +TEST(NestedMsg, MapToNestedMsgBySInfo) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + auto sf2 = TupleStructInfo({sf0, sf0}); + auto x = relax::Var("x", TupleStructInfo({sf1, sf2, sf0})); + + auto msg = MapToNestedMsgBySInfo(x, [](Expr value) { return value; }); + + EXPECT_TRUE(msg.IsNested()); + auto arr = msg.NestedArray(); + + EXPECT_TRUE(arr[1].IsNested()); + auto arr1 = arr[1].NestedArray(); + + EXPECT_TRUE(arr1[0].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0))); + + EXPECT_TRUE(arr[2].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2))); +} + +TEST(NestedMsg, NestedMsgToExpr) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + relax::Var x("x", sf0), y("y", sf0), z("z", sf0); + + NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; + auto expr = NestedMsgToExpr(msg, [&](Optional leaf) { + ICHECK(leaf.defined()); + int value = leaf.value().IntValue(); + switch (value) { + case 0: + return x; + case 1: + return y; + default: + return z; + } + }); + + Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})}); + EXPECT_TRUE(StructuralEqual()(expr, expected)); + + // test simplified + relax::Var t("t", sf1); + NestedMsg msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)}; + auto expr1 = NestedMsgToExpr(msg1, [](Optional leaf) { return leaf.value(); }); + EXPECT_TRUE(StructuralEqual()(expr1, t)); +} + +TEST(NestedMsg, CombineNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + NestedMsg lhs = {c0, {c0, c1}, NullOpt, {c0, {c1, c2}}}; + NestedMsg rhs = {c1, {c2, NullOpt}, NullOpt, {c1, {c2, c2}}}; + NestedMsg expected = {c1, {c2, c1}, NullOpt, {c1, {c2, c2}}}; + + auto output = CombineNestedMsg(lhs, rhs, [](Integer x, Integer y) { + if (x->value > y->value) return x; + return y; + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, MapNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + auto c3 = Integer(3); + + NestedMsg msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}}; + + auto output = MapNestedMsg(msg, [](Integer x) { + if (x->value == 0) { + return NestedMsg(Integer(3)); + } else if (x->value == 1) { + return NestedMsg(); + } else { + return NestedMsg(x); + } + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, TransformTupleLeaf) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + using NInt = NestedMsg; + + NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}}; + NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}}; + + PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32)); + relax::Var x("x", s), y("y", s), z("z", s); + BlockBuilder bb = BlockBuilder::Create(NullOpt); + Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})})); + + auto ftransleaf = 
[&](Expr value, std::array msgs) -> Expr { + int lhs = Downcast(msgs[0].LeafValue())->value; + int rhs = Downcast(msgs[1].LeafValue())->value; + if (lhs > rhs) + return z; + else if (lhs == rhs) + return value; + else + return y; + }; + + Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})}); + + EXPECT_TRUE(StructuralEqual()( + TransformTupleLeaf(expr, std::array({msg1, msg2}), ftransleaf), expected)); + + EXPECT_TRUE( + expr.same_as(TransformTupleLeaf(expr, std::array({msg1, msg1}), ftransleaf))); +} diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index c4a404198c16..ff4c96d4d13c 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -131,6 +131,7 @@ "apps/wasm-standalone/wasm-graph/.cargo/config", # html for demo purposes "web/apps/browser/rpc_server.html", + "web/apps/browser/rpc_plugin.html", # images are normally not allowed # discuss with committers before add more images "apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png", diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py new file mode 100644 index 000000000000..823f4bdb9294 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_relax_integration.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Relax hexagon test.""" + +import numpy as np +import pytest +import tvm.testing +from tvm import relay, relax, runtime +from tvm.relax.testing import relay_translator +from tvm.contrib.hexagon.session import Session +from tvm.relay import testing + + +class TestConv2d: + """Test conv2d op""" + + n_batch = tvm.testing.parameter(1, relay.Any()) + + @tvm.testing.requires_hexagon + def test_conv2d(self, hexagon_session: Session, n_batch): + """Test Relax conv2d op and compare with Relay""" + dtype = "float32" + data = relay.var("data", relay.TensorType((n_batch, 64, 64, 3), dtype)) + weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype)) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="float32", + ) + f = relay.Function([data, weight], y) + relay_mod = tvm.IRModule.from_expr(f) + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + relax_mod = relay_translator.from_relay(relay_mod["main"], target) + + exe = relax.build(relax_mod, target) + dev = hexagon_session.device + vm_mod = hexagon_session.get_executor_from_factory(exe) + vm_rt = relax.VirtualMachine(vm_mod, dev) + + data_np = np.random.rand(1, 64, 64, 3).astype(np.float32) + weight_np = np.random.rand(5, 5, 3, 8).astype(np.float32) + + # Run on hexagon and get result + data = tvm.nd.array(data_np, dev) + weight = tvm.nd.array(weight_np, dev) + vm_rt.set_input("main", data, weight) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on Relay for comparison. + dev = tvm.cpu() + data = tvm.nd.array(data_np, dev) + weight = tvm.nd.array(weight_np, dev) + + target = tvm.target.Target("llvm", host="llvm") + vm_exec = relay.vm.compile(relay_mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + relay_res = vm_factory.invoke("main", data, weight) + tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3) + + +class TestMLP: + """Test MLP""" + + n_batch = tvm.testing.parameter(1, relay.Any()) + + @tvm.testing.requires_hexagon + def test_mlp(self, hexagon_session: Session, n_batch): + """Test Relax MLP and compare with Relay""" + relay_mod, params = testing.mlp.get_workload(batch_size=n_batch, dtype="float32") + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + + exe = relax.build(relax_mod, target) + hexagon_device = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(exe) + vm_rt = relax.VirtualMachine(vm_mod, hexagon_device) + + shape = (1, 1, 28, 28) + data_np = np.random.rand(*shape).astype("float32") + data = tvm.nd.array(data_np, hexagon_device) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on Relay for comparison. 
+ cpu_dev = tvm.cpu() + data = tvm.nd.array(data_np, cpu_dev) + + target = tvm.target.Target("llvm", host="llvm") + vm_exec = relay.vm.compile(relay_mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, cpu_dev) + relay_res = vm_factory.invoke("main", data, **params) + tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3) + + +def get_onnx_mobilenet(): + """Download and import mobilenet model with ONNX""" + import onnx # pylint: disable=import-outside-toplevel + + # pylint: disable=line-too-long + model_url = "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx" + model_path = tvm.contrib.download.download_testdata( + model_url, "mobilenetv2-7.onnx", module="onnx" + ) + return onnx.load(model_path) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet_onnx(hexagon_session: Session): + """Test MobileNetV2 ONNX model""" + onnx_model = get_onnx_mobilenet() + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + shape_dict = {"input": data_np.shape} + relay_mod, _ = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + relax_mod = relay_translator.from_relay(relay_mod["main"], target_hexagon) + + # Compile and run on Hexagon. + exe = relax.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(exe) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on LLVM for comparison. + relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm") + exe = relax.build(relax_mod, "llvm") + dev = tvm.cpu() + vm_rt = relax.VirtualMachine(exe, dev) + data = tvm.nd.array(data_np, dev) + llvm_res = vm_rt["main"](data) + tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet(hexagon_session: Session): + """Test MobileNet workload""" + relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + # translate the relay mobilenet and bind params + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + + # Compile and run on Hexagon. + exe = relax.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(exe) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on LLVM for comparison. 
+ relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm", params) + exe = relax.build(relax_mod, "llvm") + dev = tvm.cpu() + vm_rt = relax.VirtualMachine(exe, dev) + data = tvm.nd.array(data_np, dev) + llvm_res = vm_rt["main"](data) + tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet_dyn(hexagon_session: Session): + """Test MobileNet workload with dynamic batch size""" + relay_mod, params = testing.mobilenet.get_workload(batch_size=relay.Any(), dtype="float32") + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + # translate the relay mobilenet and bind params + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + + # Compile and run on Hexagon. + exe = relax.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(exe) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on Relay for comparison. + dev = tvm.cpu() + data = tvm.nd.array(data_np, dev) + + target = tvm.target.Target("llvm", host="llvm") + vm_exec = relay.vm.compile(relay_mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + relay_res = vm_factory.invoke("main", data, **params) + tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py index f4342f5814df..1c68d084f798 100644 --- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -24,6 +24,8 @@ from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET from tvm.relay.backend import Executor from tvm.relay.testing import run_opt_pass, run_infer_type +from tvm.relax.testing import relay_translator +from .infrastructure import get_hexagon_target @tvm.testing.requires_hexagon @@ -471,5 +473,73 @@ def test_qnn_tanh(self, hexagon_session: Session): np.testing.assert_equal(hexagon_output, llvm_output) +def test_qnn_conv2d_is_scalar_relax(): + """Test to check if the input scale and output scale is constant, + qnn.requantize will compute with fixed_point_value.""" + + data_shape = (1, 64, 56, 56) + kernel_shape = (128, 64, 3, 3) + + data_dtype = "uint8" + in_data = relay.var("data", shape=data_shape, dtype=data_dtype) + + kernel_dtype = "int8" + kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype) + azp = np.array([0]).astype("int32") + wzp = np.array([0]).astype("int32") # assumed zero + bias = (np.zeros((1, 512, 1, 1), dtype="uint32") * -12).astype("int32") + rqsci = np.array([1]).astype("float32") + rqzpi = np.array([0]).astype("int32") + rqsco = np.array([1]).astype("float32") + rqzpo = np.array([0]).astype("int32") + strides = (1, 1) + + input_zero_point = relay.const(azp[0], dtype="int32") + kernel_zero_point = relay.const(wzp[0], dtype="int32") + + input_scale = relay.const(1.0, dtype="float32") + kernel_scale = relay.const(1.0, dtype="float32") + + conv_op = relay.qnn.op.conv2d( + in_data, + kernel, + 
input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=input_scale, + kernel_scale=kernel_scale, + kernel_size=(kernel_shape[2], kernel_shape[3]), + channels=kernel_shape[0], + strides=strides, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + + bias = relay.var("bias", shape=(kernel_shape[0],), dtype="int32") + bias_op = relay.nn.bias_add(conv_op, bias, axis=1) + + requant_op = relay.qnn.op.requantize( + bias_op, + input_scale=relay.const(rqsci[0]), + input_zero_point=relay.const(rqzpi[0]), + output_scale=relay.const(rqsco[0]), + output_zero_point=relay.const(rqzpo[0]), + out_dtype="int32", + ) + + clip_op = relay.op.clip(requant_op, 0.0, 255.0) + cast_op = relay.op.cast(clip_op, "uint8") + + func = relay.Function([in_data, kernel, bias], cast_op) + + mod = tvm.IRModule.from_expr(func) + target_hexagon = get_hexagon_target("v69") + relax_mod = relay_translator.from_relay( + mod["main"], target_hexagon, disabled_pass=["qnn.Legalize"] + ) + + assert "requantize_scalar" in relax_mod.astext(show_meta_data=False) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py new file mode 100644 index 000000000000..f1b1187066e6 --- /dev/null +++ b/tests/python/relax/conftest.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations + +import pytest + +import tvm +from tvm.relax.ir.instrument import WellFormedInstrument + + +tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()]) diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py new file mode 100644 index 000000000000..72a256d733a2 --- /dev/null +++ b/tests/python/relax/test_analysis.py @@ -0,0 +1,425 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
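+"""Tests for Relax analysis utilities: use-def chains, removal of unused bindings, variable queries, and reshape-pattern detection."""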
+ +from typing import List, Set, Union + +import tvm +import tvm.testing +from tvm import tir +from tvm import relax as rx +from tvm.relax.analysis import ( + has_reshape_pattern, + udchain, + remove_all_unused, + name_to_binding, + all_vars, + all_global_vars, + free_vars, + bound_vars, +) +from tvm.script import relax as R, tir as T + + +def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]: + return set(map(lambda v: v.name_hint, vars)) + + +def test_use_def(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float16")) + y = rx.Var("y", R.Tensor([n], "float16")) + ib = rx.BlockBuilder() + with ib.function("func", [x, y]): + with ib.dataflow(): + lv0 = ib.emit(rx.op.add(x, y)) + lv1 = ib.emit(rx.op.multiply(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_func_output(gv0) + dfb = ib.get()["func"].body.blocks[0] + udc = udchain(dfb) + assert set(udc[x]) == {lv0} + assert set(udc[y]) == {lv0, lv1} + assert set(udc[lv0]) == {lv1} + assert set(udc[lv1]) == {gv0} + assert set(udc[gv0]) == set() + + +def test_chained_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv0) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + # This might bring side effect so cannot be removed. 
+ z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_edge_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) + return x + + optimized = remove_all_unused(IdentityUnused["main"]) + tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) + + +def test_name_to_binding_var_shadowing(): + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + lv1 = lv0 + R.output(lv1) + + with R.dataflow(): + lv0 = lv1 # shadowing + lv2 = lv0 + R.output(lv2) + return lv2 + + n2binding = name_to_binding(main) + + assert "lv0" in n2binding + assert "lv1" in n2binding + assert "lv2" in n2binding + + assert len(n2binding["lv0"]) == 2 + + +@tvm.script.ir_module +class VarExample: + @R.function + def func(a: R.Tensor) -> R.Tensor: + # normalized into assigning R.add(a, a) to a var and returning it + return R.add(a, a) + + @R.function + def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: + cls = VarExample + z = R.add(x, y) + # no binding here + _ = R.match_cast(x, R.Tensor((5, 5))) + with R.dataflow(): + q = R.add(z, z) + p = cls.func(q) + r = R.match_cast(p, R.Tensor((5, 5))) + s = r + R.output(s) + return s + + +def test_all_vars(): + vars = all_vars(VarExample["func"]) + assert len(vars) == 2 + assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a var + assert vars[1] == VarExample["func"].body.body + + var_names = var_name_set(all_vars(VarExample["main"])) + assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} + + +def test_bound_vars(): + vars = bound_vars(VarExample["func"]) + assert len(vars) == 2 + assert vars[0].name_hint == "a" + # the body of the seq expr in the func body is a bound var + assert vars[1] == VarExample["func"].body.body + + # all the vars are bound + var_names = var_name_set(bound_vars(VarExample["main"])) + assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} + + # if we consider only the body, then the function arguments are not bound + body_names = var_name_set(bound_vars(VarExample["main"].body)) + assert body_names == {"_", "z", "p", "q", "r", "s"} + + # only binding is in the (normalized) body + simple_body_vars = bound_vars(VarExample["func"].body) + assert len(simple_body_vars) == 1 + assert simple_body_vars[0] == VarExample["func"].body.body + + +def test_free_vars(): + # all the vars are bound + assert len(free_vars(VarExample["func"])) == 0 + assert len(free_vars(VarExample["main"])) == 0 + + # the arguments are free if we look only at the bodies + func_free = var_name_set(free_vars(VarExample["func"].body)) + main_free = var_name_set(free_vars(VarExample["main"].body)) + assert len(func_free) == 1 + assert len(main_free) == 2 + assert "a" in func_free + assert main_free == {"x", "y"} + + # function that captures vars + x = rx.Var("x", R.Tensor(ndim=-1)) + y = rx.Var("y", R.Tensor(ndim=-1)) + z = rx.Var("z", R.Tensor(ndim=-1)) + inner = rx.Function( + [z], + rx.op.add(x, rx.op.add(y, z)), + ret_struct_info=R.Tensor(ndim=-1), + ) + outer = rx.Function( + [x, y], + rx.Call(inner, [y]), + ret_struct_info=R.Tensor(ndim=-1), + ) + assert len(free_vars(outer)) == 0 + assert var_name_set(free_vars(inner)) == {"x", "y"} + + +def 
test_all_global_vars(): + # there is one call to "func" + global_vars = all_global_vars(VarExample["main"]) + assert len(global_vars) == 1 + assert global_vars[0].name_hint == "func" + + gv1 = rx.GlobalVar("gv1") + gv2 = rx.GlobalVar("gv2") + gv3 = rx.GlobalVar("gv3") + call = rx.Call(gv1, [gv2, gv3]) + call_var_names = var_name_set(all_global_vars(call)) + assert call_var_names == {"gv1", "gv2", "gv3"} + + +def test_reshape_pattern_reshape(): + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0, i1 in T.grid(8, 3): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape) + + +def test_reshape_pattern_reshape_scheduled(): + @T.prim_func + def reshape_scheduled( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"): + with T.block("T_reshape"): + ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3) + ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape_scheduled) + + +def test_reshape_pattern_expand_dims(): + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((2, 3, 4), "float32"), + expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( + "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] + ) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[ + i0_1, i4_1, i6_1 + ] + + assert has_reshape_pattern(expand_dims) + + +def test_reshape_pattern_with_raggedness(): + @T.prim_func + def reshape_raggedness( + A: T.Buffer((100, 768), "float32"), + src_indptr: T.Buffer((9,), "int32"), + B: T.Buffer((100, 12, 64), "float32"), + ): + for b in T.serial(8): + with T.block("block0"): + vb = T.axis.spatial(8, b) + for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]): + for h in T.serial(12): + for f in T.serial(64): + with T.block("block1"): + vi, vh, vf = T.axis.remap("SSS", [i, h, f]) + B[src_indptr[vb] + vi, vh, vf] = A[ + src_indptr[vb] + vi, vh * 64 + vf + ] + + assert has_reshape_pattern(reshape_raggedness) + + +def test_reshape_pattern_reject_seqstmt(): + @T.prim_func + def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in 
T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + T.float32(1) + + @T.prim_func + def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + + assert not has_reshape_pattern(identity_bias) + assert not has_reshape_pattern(identity_identity) + + +def test_reshape_pattern_reject_reduction(): + @T.prim_func + def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SR", [i0, i1]) + with T.init(): + B[vi0] = T.float32(0) + B[vi0] = B[vi0] + A[vi0, vi1] + + assert not has_reshape_pattern(reduction) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_detect_recursion.py b/tests/python/relax/test_analysis_detect_recursion.py new file mode 100644 index 000000000000..b4c7adc84456 --- /dev/null +++ b/tests/python/relax/test_analysis_detect_recursion.py @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List +import tvm +import tvm.testing +from tvm import relax as rx +from tvm.script import relax as R, tir as T + +from tvm.relax.analysis import detect_recursion + + +def assert_groups(groups: List[List[rx.GlobalVar]], expected: List[List[str]]) -> None: + assert len(groups) == len(expected) + + # disregard order, search only by name for convenience + expected_sets = [set(expected_group) for expected_group in expected] + actual_sets = [set(map(lambda gv: gv.name_hint, actual_group)) for actual_group in groups] + + for expected_set in expected_sets: + assert expected_set in actual_sets + + +def test_no_recursion(): + @tvm.script.ir_module + class NoRecursion: + @R.function + def a(x: R.Object) -> R.Object: + return x + + @R.function + def b(x: R.Object) -> R.Object: + return x + + groups = detect_recursion(NoRecursion) + assert len(groups) == 0 + + +def test_simple_recursion(): + @tvm.script.ir_module + class SimpleRecursion: + @R.function + def c(x: R.Object) -> R.Object: + return SimpleRecursion.c(x) + + groups = detect_recursion(SimpleRecursion) + assert_groups(groups, ["c"]) + + +def test_tree(): + # no cycle! 
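+    # call graph: a -> b -> c, then c -> d, c -> e, d -> e; no function can reach itself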
+ @tvm.script.ir_module + class Tree: + @R.function + def a(x: R.Object) -> R.Object: + return Tree.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return Tree.c(x) + + @R.function + def c(x: R.Object) -> R.Object: + z: R.Object = Tree.d(x) + return Tree.e(z) + + @R.function + def d(x: R.Object) -> R.Object: + return Tree.e(x) + + @R.function + def e(x: R.Object) -> R.Object: + return x + + groups = detect_recursion(Tree) + assert len(groups) == 0 + + +def test_two_function_case(): + @tvm.script.ir_module + class TwoFunctionCase: + @R.function + def a(x: R.Object) -> R.Object: + return TwoFunctionCase.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return TwoFunctionCase.a(x) + + # not part of the group, shouldn't be reported + @R.function + def c(x: R.Object) -> R.Object: + return x + + groups = detect_recursion(TwoFunctionCase) + assert_groups(groups, [["a", "b"]]) + + +def test_two_groups_of_two(): + @tvm.script.ir_module + class TwoGroupsOfTwo: + @R.function + def a(x: R.Object) -> R.Object: + return TwoGroupsOfTwo.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return TwoGroupsOfTwo.a(x) + + @R.function + def c(x: R.Object) -> R.Object: + return TwoGroupsOfTwo.d(x) + + @R.function + def d(x: R.Object) -> R.Object: + return TwoGroupsOfTwo.c(x) + + # not part of either group, shouldn't be reported + @R.function + def e(x: R.Object) -> R.Object: + return x + + groups = detect_recursion(TwoGroupsOfTwo) + assert_groups(groups, [["a", "b"], ["c", "d"]]) + + +def test_mutual_recursion_and_simple_recursion(): + @tvm.script.ir_module + class MutualAndSimple: + @R.function + def a(x: R.Object) -> R.Object: + return MutualAndSimple.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return MutualAndSimple.a(x) + + # forms its own group + @R.function + def c(x: R.Object) -> R.Object: + return MutualAndSimple.c(x) + + groups = detect_recursion(MutualAndSimple) + assert_groups(groups, [["a", "b"], ["c"]]) + + +def test_simultaneous_mutual_and_simple_recursion(): + # even though both call themselves and each other, + # it should still form only one group + @tvm.script.ir_module + class SimultaneousMutualAndSimple: + @R.function + def a(x: R.Object) -> R.Object: + cls = SimultaneousMutualAndSimple + return cls.b(cls.a(x)) + + @R.function + def b(x: R.Object) -> R.Object: + cls = SimultaneousMutualAndSimple + return cls.a(cls.b(x)) + + groups = detect_recursion(SimultaneousMutualAndSimple) + assert_groups(groups, [["a", "b"]]) + + +def test_three_function_case(): + @tvm.script.ir_module + class ThreeFunctionCase: + @R.function + def a(x: R.Object) -> R.Object: + return ThreeFunctionCase.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return ThreeFunctionCase.c(x) + + @R.function + def c(x: R.Object) -> R.Object: + return ThreeFunctionCase.a(x) + + groups = detect_recursion(ThreeFunctionCase) + assert_groups(groups, [["a", "b", "c"]]) + + +def test_call_from_outside_of_group(): + @tvm.script.ir_module + class CallFromOutOfGroup: + # A calls into a group of mutually recursive functions, + # but is not part of the cycle + @R.function + def a(x: R.Object) -> R.Object: + return CallFromOutOfGroup.d(x) + + @R.function + def b(x: R.Object) -> R.Object: + return CallFromOutOfGroup.c(x) + + @R.function + def c(x: R.Object) -> R.Object: + return CallFromOutOfGroup.d(x) + + @R.function + def d(x: R.Object) -> R.Object: + return CallFromOutOfGroup.b(x) + + # E also calls into the cycle but isn't part of it + @R.function + def e(x: R.Object) -> R.Object: + return 
CallFromOutOfGroup.b(x) + + groups = detect_recursion(CallFromOutOfGroup) + assert_groups(groups, [["b", "c", "d"]]) + + +def test_call_from_group_to_outside(): + @tvm.script.ir_module + class CallFromGroupToOutside: + # A calls into a group of mutually recursive functions, + # but is not part of the cycle + @R.function + def a(x: R.Object) -> R.Object: + return CallFromGroupToOutside.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + # d is called from a member of the group but it is not part of the cycle + z: R.Object = CallFromGroupToOutside.d(x) + return CallFromGroupToOutside.c(z) + + @R.function + def c(x: R.Object) -> R.Object: + return CallFromGroupToOutside.a(x) + + @R.function + def d(x: R.Object) -> R.Object: + return x + + groups = detect_recursion(CallFromGroupToOutside) + assert_groups(groups, [["a", "b", "c"]]) + + +def test_group_with_two_cycles(): + """ + a -> b <- f + ^ | ^ + | v | + d <- c -> e + + There are two smaller cycles in this group, + but you can have one big cycle + B -> C -> D -> A -> B -> C -> E -> F -> B + """ + + @tvm.script.ir_module + class GroupWithTwoCycles: + @R.function + def a(x: R.Object) -> R.Object: + return GroupWithTwoCycles.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + return GroupWithTwoCycles.c(x) + + @R.function + def c(x: R.Object) -> R.Object: + y = GroupWithTwoCycles.d(x) + return GroupWithTwoCycles.e(y) + + @R.function + def d(x: R.Object) -> R.Object: + return GroupWithTwoCycles.a(x) + + @R.function + def e(x: R.Object) -> R.Object: + return GroupWithTwoCycles.f(x) + + @R.function + def f(x: R.Object) -> R.Object: + return GroupWithTwoCycles.b(x) + + groups = detect_recursion(GroupWithTwoCycles) + assert_groups(groups, [["a", "b", "c", "d", "e", "f"]]) + + +def test_multicycle_example(): + """ + Example from the documentation + A <-> B <-> C + ^ | ^ + | v | + | D | + | | | + v v v + E <-> F <-> G + """ + + @tvm.script.ir_module + class MulticycleExample: + @R.function + def a(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.b(x) + return cls.e(y) + + @R.function + def b(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.a(x) + z = cls.c(y) + return cls.d(z) + + @R.function + def c(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.g(x) + return cls.b(y) + + @R.function + def d(x: R.Object) -> R.Object: + cls = MulticycleExample + return cls.f(x) + + @R.function + def e(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.f(x) + return cls.a(y) + + @R.function + def f(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.g(x) + return cls.e(y) + + @R.function + def g(x: R.Object) -> R.Object: + cls = MulticycleExample + y = cls.f(x) + return cls.c(y) + + groups = detect_recursion(MulticycleExample) + assert_groups(groups, [["a", "b", "c", "d", "e", "f", "g"]]) + + +def test_control_flow(): + @tvm.script.ir_module + class ControlFlowExample: + @R.function + def a(x: R.Object) -> R.Object: + cls = ControlFlowExample + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + ret = cls.b(x) + else: + ret = cls.c(x) + return ret + + @R.function + def b(x: R.Object) -> R.Object: + cls = ControlFlowExample + return cls.a(x) + + @R.function + def c(x: R.Object) -> R.Object: + cls = ControlFlowExample + return cls.a(x) + + groups = detect_recursion(ControlFlowExample) + assert_groups(groups, [["a", "b", "c"]]) + + +def test_returning_self(): + @tvm.script.ir_module + class ReturnsSelf: + @R.function + def a() -> R.Object: + # this is also a form of recursion + 
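+            # (it returns its own GlobalVar without calling it)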
return ReturnsSelf.a + + groups = detect_recursion(ReturnsSelf) + assert_groups(groups, [["a"]]) + + +def test_mutual_recursion_via_references(): + @tvm.script.ir_module + class GatherReferences: + @R.function + def a(x: R.Object) -> R.Object: + cls = GatherReferences + return cls.b(x) + + @R.function + def b(x: R.Object) -> R.Object: + cls = GatherReferences + return (cls.a, cls.b, cls.c) + + @R.function + def c(x: R.Object) -> R.Object: + cls = GatherReferences + return cls.a(x) + + groups = detect_recursion(GatherReferences) + assert_groups(groups, [["a", "b", "c"]]) + + +def test_disregard_primfuncs(): + @tvm.script.ir_module + class CallPrimFunc: + # copied from test_analysis.py + @T.prim_func + def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + + @R.function + def a(x: R.Tensor((4, 4), "float32")) -> R.Object: + cls = CallPrimFunc + y = R.call_tir(cls.identity_identity, x, R.Tensor((4, 4), "float32")) + return cls.b(y) + + @R.function + def b(x: R.Tensor((4, 4), "float32")) -> R.Object: + cls = CallPrimFunc + y = R.call_tir(cls.identity_identity, x, R.Tensor((4, 4), "float32")) + return cls.a(y) + + groups = detect_recursion(CallPrimFunc) + # the prim func should not be listed here + assert_groups(groups, [["a", "b"]]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py new file mode 100644 index 000000000000..32bb56a670b9 --- /dev/null +++ b/tests/python/relax/test_analysis_estimate_memory_usage.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
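+"""Tests relax.analysis.estimate_memory_usage on a module with explicit memory alloc/kill operations."""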
+ +import tvm +import tvm.testing +from tvm.script import relax as R, tir as T +from tvm.relax.analysis import estimate_memory_usage + + +def test_basic(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + rxplaceholder: T.Buffer(T.int64(8), "float32"), + rxplaceholder_1: T.Buffer((), "float32"), + T_add: T.Buffer(T.int64(8), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), + T_reshape: T.Buffer(T.int64(8), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def relu( + rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32") + ): + T.evaluate(0) + + @T.prim_func + def log( + rxplaceholder: T.Buffer(T.int64(10), "float32"), + compute: T.Buffer(T.int64(10), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def exp( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), + compute: T.Buffer((T.int64(2), T.int64(4)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def pad( + rxplaceholder: T.Buffer(T.int64(8), "float32"), + PadInput: T.Buffer(T.int64(10), "float32"), + ): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Module + storage: R.Object = R.memory.alloc_storage( + R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor( + storage, offset=0, shape=R.shape([2, 4]), dtype="float32" + ) + _: R.Tuple() = cls.exp(x, alloc) + lv: R.Tensor((2, 4), dtype="float32") = alloc + lv1: R.Tensor((8,), dtype="float32") = R.call_packed( + "vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")] + ) + storage1: R.Object = R.memory.alloc_storage( + R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage1, offset=0, shape=R.shape([8]), dtype="float32" + ) + _1: R.Tuple() = cls.relu(lv1, alloc1) + _2: R.Tuple() = R.memory.kill_tensor(alloc) + _3: R.Tuple() = R.memory.kill_tensor(lv1) + lv2: R.Tensor((8,), dtype="float32") = alloc1 + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage, offset=0, shape=R.shape([8]), dtype="float32" + ) + _4: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2) + _5: R.Tuple() = R.memory.kill_tensor(alloc1) + lv3: R.Tensor((8,), dtype="float32") = alloc2 + alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor( + storage1, offset=0, shape=R.shape([10]), dtype="float32" + ) + _6: R.Tuple() = cls.pad(lv3, alloc3) + _7: R.Tuple() = R.memory.kill_tensor(alloc2) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), dtype="float32", runtime_device_index=0 + ) + _8: R.Tuple() = cls.log(lv4, alloc4) + _9: R.Tuple() = R.memory.kill_tensor(alloc3) + gv5: R.Tensor((10,), dtype="float32") = alloc4 + _11: R.Tuple() = R.memory.kill_storage(storage) + _10: R.Tuple() = R.memory.kill_storage(storage1) + return gv5 + + assert ( + estimate_memory_usage(Module) + == r"""Memory usage estimation: +- Function main: + * Without memory planning, there are 5 constant-size memory allocation(s) with total size 1.639e-07 GB. + * With memory planning, there are 2 constant-size memory allocation(s) with total size 6.706e-08 GB. 
+ * Memory planning reduces constant memory size to 40.9%.""" + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py new file mode 100644 index 000000000000..03b98f8a565e --- /dev/null +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -0,0 +1,561 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests analysis functions of struct info""" + +import pytest +import tvm +import tvm.testing +from tvm import relax as rx, TVMError +from tvm import tir + + +def test_get_static_type_basic(): + # object + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) + + +def test_get_static_type_shape(): + # shape + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2)) + + +def test_get_static_type_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(s4), rx.DynTensorType(ndim=3, dtype="int64") + ) + + +def test_get_static_type_tuple(): + # tuple + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(t1), + rx.TupleType( + [ + rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]), + rx.ShapeType(ndim=3), + ] + ), + ) + + +def test_get_static_type_func(): + # tuple + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_type(): + x = rx.DynTensorType(ndim=3, dtype="float32") + y = rx.DynTensorType(ndim=3, dtype="float32") + z = rx.DynTensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(fn_info(1)), fn_type()) + + +def test_erase_to_well_defined_basic(): + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) + + # prim + s1 = 
rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) + + +def test_erase_to_well_defined_shape(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + # have undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) + ) + # all defined + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) + + # replacement + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) + ) + + # partial defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) + ) + + +def test_erase_to_well_defined_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + s0 = rx.TensorStructInfo(rshape, dtype="int32") + + # undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, None), + rx.TensorStructInfo(ndim=2, dtype="int32"), + ) + + # defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0 + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), + rx.TensorStructInfo([1, 2], dtype="int32"), + ) + + s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), + rx.TensorStructInfo([4, 2], dtype="float32"), + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + s2 = rx.TensorStructInfo([1, 2], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) + + +def test_erase_to_well_defined_tuple(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(t1, {m: m + 1}), + rx.TupleStructInfo( + [ + rx.TupleStructInfo( + [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] + ), + rx.ShapeStructInfo([1, m + 1]), + ] + ), + ) + + +def test_erase_to_well_defined_func(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) + + +def test_base_check(): + BR = rx.analysis.BaseCheckResult + bcheck = rx.analysis.struct_info_base_check + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 
= rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + assert bcheck(obj0, prim0) == BR.PASS + assert bcheck(obj0, shape1) == BR.PASS + assert bcheck(obj0, tensor2) == BR.PASS + assert obj0.is_base_of(tensor2) + + # prim + assert prim0.is_base_of(prim0) + assert not prim0.is_base_of(prim1) + assert bcheck(prim0, obj0) == BR.FAIL_L1 + assert bcheck(prim0, prim0) == BR.PASS + assert bcheck(prim0, prim1) == BR.FAIL_L0 + + # shape + assert bcheck(shape0, obj0) == BR.FAIL_L1 + assert bcheck(shape0, prim0) == BR.FAIL_L0 + + # unknown dim + assert bcheck(shape0, shape1) == BR.PASS + assert bcheck(shape1, shape0) == BR.FAIL_L1 + + # ndim mismatch + assert bcheck(shape1, shape2) == BR.FAIL_L0 + + # lhs do not have symbolic value but ndim match + assert bcheck(shape2, shape3) == BR.PASS + + # rhs do not symbolic but lhs do + assert bcheck(shape3, shape2) == BR.FAIL_L2 + + # shape mismatch + assert bcheck(shape3, shape4) == BR.FAIL_L2 + assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) + + # tensor + assert bcheck(tensor0, obj0) == BR.FAIL_L1 + assert bcheck(tensor0, prim0) == BR.FAIL_L0 + assert bcheck(tensor0, shape0) == BR.FAIL_L0 + + # dtype mismatch + assert bcheck(tensor0, tensor1) == BR.FAIL_L0 + assert bcheck(tensor0, tensor3) == BR.FAIL_L0 + assert bcheck(tensor3, tensor4) == BR.FAIL_L0 + assert bcheck(tensor1, tensor2) == BR.FAIL_L0 + + # ndim mismatch + assert bcheck(tensor2, tensor5) == BR.FAIL_L0 + + # static shape mismatch + assert bcheck(tensor5, tensor6) == BR.FAIL_L0 + + # match + assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) + assert tensor0.is_base_of(tensor2) + assert tensor0.is_base_of(tensor4) + assert tensor0.is_base_of(tensor5) + assert tensor0.is_base_of(tensor6) + assert tensor2.is_base_of(tensor4) + assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + assert t0.is_base_of(t1) + + assert bcheck(t0, t2) == BR.FAIL_L0 + assert bcheck(t0, t3) == BR.FAIL_L1 + + assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) + assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + assert fn_info_shape(1).is_base_of(fn_info_shape(1)) + assert fn_info_erased().is_base_of(fn_info_shape(1)) + assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 + + fopaque = rx.FuncStructInfo.opaque_func() + assert fopaque.is_base_of(fn_info_shape(1)) + + +def _check_derive(ctx, finfo, args_sinfo, ret): + gv = rx.GlobalVar("test") + rx.expr._update_struct_info(gv, finfo) + args = [] + for i, sinfo in enumerate(args_sinfo): + arg = rx.Var("arg%i" % i, sinfo) + args.append(arg) + call = rx.Call(gv, args) + derived_ret = 
rx.analysis.derive_call_ret_struct_info(finfo, call, ctx)
+    tvm.ir.assert_structural_equal(ret, derived_ret)
+
+
+def test_derive_call_ret_struct_info():
+    obj0 = rx.ObjectStructInfo()
+    prim0 = rx.PrimStructInfo("float32")
+
+    n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64")
+    bb = rx.BlockBuilder()
+    # derivation cases
+    with bb.testing_scope(def_vars=[n, m]):
+
+        def func0(c):
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            x = rx.TensorStructInfo([n, m], "float32")
+            z = rx.TensorStructInfo([m + c, n], "float32")
+            return rx.FuncStructInfo([x], z)
+
+        # Tensor => Tensor
+        _check_derive(
+            bb,
+            func0(1),
+            [rx.TensorStructInfo([10, 11], "float32")],
+            rx.TensorStructInfo([12, 10], "float32"),
+        )
+
+        _check_derive(
+            bb,
+            func0(2),
+            [rx.TensorStructInfo([n, m], "float32")],
+            rx.TensorStructInfo([m + 2, n], "float32"),
+        )
+
+        # passing in information that cannot deduce n, m
+        # is still OK as the type still matches; the result is an
+        # erased output
+        _check_derive(
+            bb,
+            func0(2),
+            [rx.TensorStructInfo(ndim=2, dtype="float32")],
+            rx.TensorStructInfo(ndim=2, dtype="float32"),
+        )
+
+        # Error: wrong number of arguments
+        with pytest.raises(TVMError):
+            _check_derive(
+                bb,
+                func0(2),
+                [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0],
+                rx.TensorStructInfo(ndim=2, dtype="float32"),
+            )
+
+        # Error: type mismatch
+        with pytest.raises(TVMError):
+            _check_derive(bb, func0(2), [obj0], obj0)
+
+        # opaque derivation
+        fopaque0 = lambda: rx.FuncStructInfo.opaque_func()
+        fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0)
+        _check_derive(bb, fopaque0(), [obj0, prim0], obj0)
+        _check_derive(bb, fopaque1(), [obj0, prim0], prim0)
+
+        # recursive tuple derivation
+        def func_tuple0(c):
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            x0 = rx.TensorStructInfo([n, c], "float32")
+            x1 = rx.TensorStructInfo([n + c, m], "float32")
+            z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")])
+            return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z)
+
+        _check_derive(
+            bb,
+            func_tuple0(2),
+            [
+                rx.TupleStructInfo(
+                    [
+                        rx.TensorStructInfo([n, 2], "float32"),
+                        rx.TensorStructInfo([n + 2, 10], "float32"),
+                    ]
+                )
+            ],
+            rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]),
+        )
+
+        def func_tuple1(c):
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            x0 = rx.TensorStructInfo([n, m], "float32")
+            x1 = rx.TensorStructInfo([n + c, c], "float32")
+            z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")])
+            return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z)
+
+        # Still OK to pass an erased tensor for the [n + c, c] argument; n is captured by the other argument.
+        _check_derive(
+            bb,
+            func_tuple1(4),
+            [
+                rx.TupleStructInfo(
+                    [
+                        rx.TensorStructInfo([n, 4], "float32"),
+                        rx.TensorStructInfo(ndim=2, dtype="float32"),
+                    ]
+                )
+            ],
+            rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]),
+        )
+
+        # tuple length mismatch causes an error
+        with pytest.raises(TVMError):
+            _check_derive(
+                bb,
+                func_tuple0(4),
+                [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])],
+                rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]),
+            )
+
+        # mixed shape types
+        def func_shape_mixed(c):
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            x0 = rx.ShapeStructInfo([n, m])
+            f0 = func_tuple0(c)
+            z = rx.ShapeStructInfo([m + n, c])
+            return rx.FuncStructInfo([x0, f0], z)
+
+        _check_derive(
+            bb,
+            func_shape_mixed(3),
+            [
+                rx.ShapeStructInfo([10, 20]),
+                rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2)),
+            ],
+            rx.ShapeStructInfo([30, 3]),
+        )
+
+
+def _check_lca(lhs, rhs, target):
+    tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target)
+    tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target)
+
+
+def test_struct_info_lca():
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    obj0 = rx.ObjectStructInfo()
+    prim0 = rx.PrimStructInfo("int32")
+    prim1 = rx.PrimStructInfo("float32")
+
+    shape0 = rx.ShapeStructInfo(ndim=-1)
+    shape1 = rx.ShapeStructInfo(ndim=2)
+    shape2 = rx.ShapeStructInfo(ndim=3)
+    shape3 = rx.ShapeStructInfo([1, 2, 3])
+    shape4 = rx.ShapeStructInfo([1, n, 3])
+
+    tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32")
+    tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32")
+    tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32")
+    tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32")
+    tensor4 = rx.TensorStructInfo([n, m], "int32")
+    tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
+    tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
+
+    # obj
+    _check_lca(obj0, prim0, obj0)
+    _check_lca(obj0, prim1, obj0)
+
+    # shape
+    _check_lca(shape0, tensor0, obj0)
+    _check_lca(shape0, shape1, shape0)
+    _check_lca(shape1, shape2, shape0)
+    _check_lca(shape1, shape3, shape0)
+
+    _check_lca(shape2, shape3, shape2)
+    _check_lca(shape3, shape4, shape2)
+    _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4)
+
+    # tensor
+    _check_lca(tensor0, prim0, obj0)
+    _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None))
+    _check_lca(tensor0, tensor2, tensor0)
+    _check_lca(tensor0, tensor4, tensor0)
+
+    _check_lca(tensor2, tensor4, tensor2)
+    _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32"))
+    _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32"))
+    _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4)
+
+    # tuple
+    t0 = rx.TupleStructInfo([obj0, tensor0])
+    t1 = rx.TupleStructInfo([prim0, tensor4])
+    t2 = rx.TupleStructInfo([obj0, tensor0, obj0])
+    t3 = rx.TupleStructInfo([tensor0, obj0])
+
+    _check_lca(t0, t1, t0)
+    _check_lca(t0, t2, obj0)
+    _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0]))
+
+    t5 = rx.TupleStructInfo([t0, t1])
+    t6 = rx.TupleStructInfo([t1, t2])
+
+    _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0]))
+
+    t7 = rx.TupleStructInfo([])
+    _check_lca(t7, rx.TupleStructInfo([]), t7)
+
+    def fn_info_shape(c):
+        n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+        x = rx.TensorStructInfo([c, n, m], "float32")
+        y = rx.TensorStructInfo([c, n, 1], "float32")
+        z = rx.TensorStructInfo([c, n], "float32")
+        return rx.FuncStructInfo([x, y], z)
+
+    def fn_info_erased():
+        x =
rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + fopaque2 = lambda: rx.FuncStructInfo.opaque_func( + ret=rx.TensorStructInfo(ndim=2, dtype="float32") + ) + + _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) + _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) + + _check_lca(fopaque0(), fopaque1(), fopaque0()) + _check_lca(fopaque0(), fn_info_shape(1), fopaque0()) + _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py b/tests/python/relax/test_analysis_suggest_layout_transforms.py new file mode 100644 index 000000000000..2850f0ed9f94 --- /dev/null +++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py @@ -0,0 +1,831 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm.testing + +from tvm import relax, tir +from tvm.script import tir as T + + +def apply_transformations(func, suggested_transfoms, print_transformation=False): + sch = tir.Schedule(func) + for block, per_block_transformations in suggested_transfoms.items(): + blockrv = sch.get_block(block.name_hint) + for obj, index_map in per_block_transformations.items(): + if isinstance(obj, tir.Block): + block_name = obj.name_hint + if print_transformation: + print("Block transformation: ", block_name, " :: ", index_map) + sch.transform_block_layout(block_name, index_map) + else: + assert isinstance(obj, tir.Buffer) + buffer = obj + if print_transformation: + print("Buffer transformation: ", buffer, " :: ", index_map) + sch.transform_layout(blockrv, buffer, index_map) + return sch.mod["main"] + + +def test_nested_blocks(): + @T.prim_func + def nested_block( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i, j in T.grid(32, 64): + with T.block("outer"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(arg[v_i, v_j, 0:224, 0:224]) + T.writes(relu[v_i, v_j, 0:224, 0:224]) + for k, l in T.grid(224, 224): + with T.block("inner"): + v_k, v_l = T.axis.remap("SS", [k, l]) + T.reads(arg[v_i, v_j, v_k, v_l]) + T.writes(relu[v_i, v_j, v_k, v_l]) + relu[v_i, v_j, v_k, v_l] = T.max(arg[v_i, v_j, v_k, v_l], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=nested_block, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + # no suggestions for nested block. 
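+    # even though a valid NCHW -> NHWC transform was supplied for the write buffer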
+ assert len(suggested_transforms.items()) == 0 + + +def test_mismatch_transformations_and_num_params(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Incompatible PrimFunc and write_transformations"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + ], + ) + + +def test_empty_write_transformations(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_bijective_block_transform(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1 in T.grid(32, 64): + with T.block("compute"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 5, c % 5)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_affine_access(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32 * 64, 10), "float32"), + ): + for ax0, ax1, ax2 in T.grid(32, 64, 10): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * v_ax1, v_ax2]) + output[v_ax0 * v_ax1, v_ax2] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (b, a)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unsupported_write_spatial_layout(): + @T.prim_func + def before( + arg: T.Buffer((4, 4), "float32"), + output: T.Buffer((16), "float32"), + ): + for ax0, ax1 in T.grid(4, 4): + with T.block("flatten"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * 4 + v_ax1]) + output[v_ax0 * 4 + v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a: (a // 4, a % 4)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unpacked_iter_used_in_read_access(): + @T.prim_func + def before( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((4, 8), "float32"), + ): + for ax0, ax1, ax2 in T.grid(4, 8, 4): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax1, v_ax2]) + 
T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((32), "float32"), + ): + for ax0, ax2 in T.grid(32, 4): + with T.block("compute"): + v_ax0, v_ax2 = T.axis.remap("SS", [ax0, ax2]) + T.reads(arg[v_ax0 % 8, v_ax2]) + T.writes(output[v_ax0]) + output[v_ax0] = arg[v_ax0 % 8, v_ax2] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (a * 8 + b)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_invalid_index_map(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Mismatch between output buffer shape and index map"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[lambda n, h, w: (n, w, h)] + ) + with pytest.raises(AssertionError): + _ = relax.analysis.suggest_layout_transforms(func=elemwise, write_buffer_transforms=[2]) + + +def test_SRSR_block(): + @T.prim_func + def before( + arg: T.Buffer((32, 224, 64, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, k2, ax1, k3 in T.grid(32, 224, 64, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_k2, v_ax1, v_k3 = T.axis.remap("SRSR", [ax0, k2, ax1, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 16, 224, 4), "float32"), + sum: T.Buffer((32, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 16, 224, 4): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SRSRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v2, v4]) + with T.init(): + sum[v0, v2, v4] = T.float32(0) + sum[v0, v2, v4] = sum[v0, v2, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise_symbolic(): + @T.prim_func + def before(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, C, H, W)) + Relu = T.match_buffer(relu, (N, C, H, W)) + for i0, i1, i2, i3 in T.grid(N, C, H, W): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(Arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(Relu[v_i0, v_i1, v_i2, v_i3]) + Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, H, W, C)) + Relu = T.match_buffer(relu, (N, H, W, C)) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with 
T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(Arg[v0, v1, v2, v3]) + T.writes(Relu[v0, v1, v2, v3]) + Relu[v0, v1, v2, v3] = T.max(Arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + relu: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(relu[v0, v1, v2, v3]) + relu[v0, v1, v2, v3] = T.max(arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw_nhwc(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + pool_max: T.Buffer((32, 64, 111, 223), "float32"), + ): + for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(32, 64, 111, 223, 2, 2): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap( + "SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1] + ) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3], + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 111, 223, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 111, 223, 64, 2, 2): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw16c_nhwc(): + @T.prim_func + def before( + arg: T.Buffer( + (32, 4, 224, 224, 16), + "float32", + ), + pool_max: T.Buffer( + (32, 4, 110, 220, 16), + "float32", 
+ ), + ): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(32, 4, 110, 220, 16, 5, 5): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads(arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4]) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 110, 220, 64), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 110, 220, 64, 5, 5): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, C, h, w, c: (n, h, w, C * 16 + c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_reduce(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1, k2, k3 in T.grid(32, 64, 224, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 4, 224, 224, 16), "float32"), + sum: T.Buffer((32, 4, 16), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 4, 224, 224, 16): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SSRRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v1, v4]) + with T.init(): + sum[v0, v1, v4] = T.float32(0) + sum[v0, v1, v4] = sum[v0, v1, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 16, c % 16)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_upsampling(): + # relay materializes the layout if H, W or D dimensions are moved or tiled. 
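+    # Accordingly, this test only moves the channel axis (NCHW -> NHWC); H and W
+    # keep their relative order and are not tiled, so the resize access pattern
+    # can still follow the suggested layout transform.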
+ @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 64, 202, 246), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 202, 246): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, 0:224, 0:224]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = arg[ + v_i0, + v_i1, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v_i2) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v_i3) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 202, 246, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 202, 246, 64): + with T.block("resize"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v3, 0:224, 0:224]) + T.writes(resize[v0, v1, v2, v3]) + resize[v0, v1, v2, v3] = arg[ + v0, + v3, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v1) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v2) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_strided_slice(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 10, 8): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + ) + T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 10, 8, 16, 4): + with T.block("T_strided_slice_with_axes"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4]) + T.writes(T_strided_slice_with_axes[v0, v1, v2, v3, v4]) + T_strided_slice_with_axes[v0, v1, v2, v3, v4] = arg[ + v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4 + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_binary_broadcast(): + @T.prim_func + def before( + arg0: T.Buffer((32, 64, 224, 224), "float32"), + arg1: T.Buffer((64, 224, 224), "float32"), + T_add: T.Buffer((32, 64, 224, 224), "float32"), + ): + T.func_attr({"tir.noalias": True}) + 
# with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg0[v_ax0, v_ax1, v_ax2, v_ax3], + arg1[v_ax1, v_ax2, v_ax3], + ) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( + arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3] + ) + + @T.prim_func + def expected( + arg0: T.Buffer((32, 224, 224, 16, 4), "float32"), + arg1: T.Buffer((224, 224, 16, 4), "float32"), + T_add: T.Buffer((32, 224, 224, 16, 4), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4): + with T.block("T_add"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg0[v0, v1, v2, v3, v4], arg1[v1, v2, v3, v4]) + T.writes(T_add[v0, v1, v2, v3, v4]) + T_add[v0, v1, v2, v3, v4] = arg0[v0, v1, v2, v3, v4] + arg1[v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_transpose(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax3, v_ax1, v_ax2]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 64, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 64, 224): + with T.block("T_transpose"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v2, v3, v1]) + T.writes(T_transpose[v0, v1, v2, v3]) + T_transpose[v0, v1, v2, v3] = arg[v0, v2, v3, v1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pad(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + PadInput: T.Buffer((32, 64, 230, 230), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 230, 230): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + 2 <= v_i2 and v_i2 < 226 and 2 <= v_i3 and v_i3 < 226, + arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2], + T.float32(2), + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 230, 230, 16, 4): + with T.block("PadInput"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 - 2, v2 - 2, v3, v4]) + T.writes(PadInput[v0, v1, v2, v3, v4]) + PadInput[v0, v1, v2, v3, v4] = T.if_then_else( + 2 <= v1 and v1 < 226 and 2 <= v2 and v2 < 226, + arg[v0, v1 - 2, v2 - 2, v3, v4], + T.float32(2), + ) + + 
suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + split0: T.Buffer((32, 224, 224, 32), "float32"), + split1: T.Buffer((32, 224, 224, 32), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(split0[v0, v1, v2, v3]) + split0[v0, v1, v2, v3] = arg[v0, v1, v2, v3] + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections_1"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3 + 32]) + T.writes(split1[v0, v1, v2, v3]) + split1[v0, v1, v2, v3] = arg[v0, v1, v2, v3 + 32] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c), lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split_tiling_split_dim(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + split0: T.Buffer((32, 224, 224, 8, 4), "float32"), + split1: T.Buffer((32, 224, 224, 8, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(split0[v0, v1, v2, v3, v4]) + split0[v0, v1, v2, v3, v4] = arg[v0, 
v1, v2, v3, v4] + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections_1"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3 + 8, v4]) + T.writes(split1[v0, v1, v2, v3, v4]) + split1[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3 + 8, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + ], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py new file mode 100644 index 000000000000..b4b68504a489 --- /dev/null +++ b/tests/python/relax/test_analysis_well_formed.py @@ -0,0 +1,537 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax as rx +from tvm import tir +from tvm.script import relax as R +from tvm.script import tir as T + +m = tir.Var("m", "int64") +n = tir.Var("n", "int64") +x = rx.Var("x", R.Tensor([m, n], "float32")) +cond = rx.Var("cond", R.Tensor([], "bool")) + + +def build_function(blocks, params=[]): + """Returns relax.function with given blocks""" + seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) + func = rx.Function([x, cond] + params, seq_expr, R.Tensor("float32")).with_attr( + "global_symbol", "foo" + ) + return func + + +def test_var(): + # Error: Var gv0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, gv0) + bindings = [rx.VarBinding(gv1, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Var gv0 is defined more than once + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_dataflow_var(): + # Error: DataflowVar lv0 is not defined + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, lv0) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = 
tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar gv0 is defined more than once + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is defined outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is used outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(lv0, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_param_var(): + v0 = rx.Var("v0", R.Tensor([m, n], "float32")) + v1 = rx.Var("v1", R.Tensor([m, n], "float32")) + v2 = rx.Var("v2", R.Tensor([m, n], "float32")) + bb = rx.BlockBuilder() + with bb.function("func1", [v0, v1]): + gv0 = bb.emit(rx.op.add(v0, v1)) + bb.emit_func_output(gv0) + with bb.function("func2", [v0, v2]): + gv0 = bb.emit(rx.op.add(v2, v1)) + bb.emit_func_output(gv0) + mod = bb.get() + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var(): + # Error: GlobalVar GlobalVar0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + globalvar = rx.GlobalVar("GlobalVar0") + call_node = rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])], + ) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var(): + # Error: Symbolic Var new_s is not defined + new_s = tir.Var("new_s", "int64") + gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var_invalid_type(): + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + dim = tir.Var("dim", "float32") + y = rx.Var("y", R.Tensor([dim], "float32")) + gv0 = rx.Var("gv0", R.Tensor([dim], "float32")) + call_node = rx.op.add(y, y) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks, [y]) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_seq_expr(): + # Error: SeqExpr in VarBinding + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build a SeqExpr + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = 
rx.op.add(x, gv0) + _bindings = [rx.VarBinding(gv1, call_node)] + _blocks = [rx.BindingBlock(_bindings)] + _seq_expr = rx.SeqExpr(_blocks, gv1) + # build a Binding with the SeqExpr as value + bindings = [rx.VarBinding(gv0, _seq_expr)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_recursive(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info)) + ipt = rx.Var("ipt", scalar_struct_info) + x0 = rx.Var("x0", scalar_struct_info) + x1 = rx.Var("x1", scalar_struct_info) + x2 = rx.Var("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.BindingBlock( + [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))] + ) + inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.BindingBlock( + [ + rx.VarBinding(f, inner_func), + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + +def test_if(): + # Error: Var defined in true/false branch is invisible in the outer scope + # except the return Var, i.e the var in the last stmt + # v_in_if is invisible in the outer scope + v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32")) + # gv0 is visible in the outer scope + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build true branch + true_bindings = [ + rx.VarBinding(v_in_if, rx.op.add(x, x)), + rx.VarBinding(gv0, rx.op.multiply(x, x)), + ] + true_blocks = [rx.BindingBlock(true_bindings)] + true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var) + # build false branch + false_bindings = [ + rx.VarBinding(v_in_if, rx.op.multiply(x, x)), + rx.VarBinding(gv0, rx.op.add(x, x)), + ] + false_blocks = [rx.BindingBlock(false_bindings)] + false_seq_expr = rx.SeqExpr(false_blocks, false_blocks[-1].bindings[-1].var) + # build If node + if_node = rx.If(cond=cond, true_branch=true_seq_expr, false_branch=false_seq_expr) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + # try to call v_in_if defined in the true/false branch + bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=True) + + +def test_if_non_seq_body(): + # Error: If node has a body that is not a seq node + if_node = rx.If(cond=cond, true_branch=x, false_branch=x) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # on the other hand, if they're wrapped in a seq node, it's fine + seq = rx.SeqExpr([], x) + new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq) + new_blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if_node, + ) + ] + ) + ] + new_func = build_function(new_blocks) + new_mod = 
tvm.IRModule.from_expr(new_func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_if_complex_condition(): + # Error: If condition must be a leaf expression + cond_tuple = rx.Tuple([cond]) + cond_idx = rx.TupleGetItem(cond_tuple, 0) + if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + cond_var = rx.Var("q", R.Tensor([], "bool")) + new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding(cond_var, cond_idx), + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if, + ), + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_tuple_get_item_nested(): + # Error: The tuple value in tuple get item must be a leaf expression + nested_tup = rx.Var( + "t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])]) + ) + double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0) + ret_var = rx.Var("r", R.Tensor([], "int32")) + f = rx.Function( + [nested_tup], + rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + f = f.with_attr("global_symbol", "f") + mod = tvm.IRModule.from_expr(f) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # okay with an intermediate binding + first_idx = rx.TupleGetItem(nested_tup, 0) + idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])) + second_idx = rx.TupleGetItem(idx_var, 0) + new_f = rx.Function( + [nested_tup], + rx.SeqExpr( + [ + rx.BindingBlock( + [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)] + ) + ], + ret_var, + ), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + new_f = new_f.with_attr("global_symbol", "new_f") + mod = tvm.IRModule.from_expr(new_f) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_complex_seq_body(): + # Error: seq expr with a body that is not a leaf expression is not permitted + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + func = rx.Function( + [x, y], + rx.SeqExpr([], rx.op.add(x, y)), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # but if the result is bound, then it's okay + z = rx.Var("z", R.Tensor([], "int32")) + new_func = rx.Function( + [x, y], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=z, + value=rx.op.add(x, y), + ) + ] + ) + ], + z, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_inline_prim_func(): + # Error: inline 
prim_func is disallowed in Relax IR + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + new_func = rx.Function( + [], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=x, + value=tir.PrimFunc([], tir.Evaluate(0)), + ), + rx.VarBinding( + var=y, + value=rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[ + rx.GlobalVar("GlobalVar0"), + rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]), + rx.ShapeExpr([]), + ], + ), + ), + ] + ) + ], + y, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + assert not rx.analysis.well_formed(new_mod, check_struct_info=False) + + +def test_ANF(): + # Error: Nested Call + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, rx.op.add(x, x)) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Call Node in Tuple + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var_vs_gsymbol(): + # Error: gsymbol "main1" not equals to the name in global var "main" + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, x)] + blocks = [rx.DataflowBlock(bindings)] + func = rx.Function( + [x], + rx.SeqExpr(blocks, gv0), + R.Tensor(ndim=2, dtype="float32"), + ).with_attr("global_symbol", "main1") + mod = tvm.IRModule({rx.GlobalVar("main"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_nested_dataflow(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_struct_info)) + x0 = rx.DataflowVar("x0", scalar_struct_info) + x1 = rx.DataflowVar("x1", scalar_struct_info) + x2 = rx.DataflowVar("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.DataflowBlock([rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)]) + inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.DataflowBlock( + [ + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(f, inner_func), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, []))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + +def test_sinfo_args_tir_var_used_before_define_call_packed(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_sinfo_args_tir_var_used_before_define_call_tir(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_dps_packed("my_func", x, 
out_sinfo=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_sinfo_erase_to_well_formed(): + # Error: The return sinfo contains undefined symbolic vars + """ + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + """ + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) + blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])] + seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) + func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr( + "global_symbol", "foo" + ) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod) + + +def test_func_sinfo_well_formed(): + @R.function + def foo(): + @R.function + def local(x: R.Tensor(["m", "n"], "float32")): + return x + + return local + + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo)) + assert rx.analysis.well_formed(mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py new file mode 100644 index 000000000000..84b8cb1d0930 --- /dev/null +++ b/tests/python/relax/test_ast_printer.py @@ -0,0 +1,700 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
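+"""Tests for the Relax AST printer (tvm.relax.testing.dump_ast / ASTPrinter).
+
+Each case builds a small Relax construct (or parses a TVMScript function) and
+checks that the dumped AST text contains the expected node names, fields, and
+struct-info/type annotations.
+"""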
+import re +from functools import partial +from typing import Dict + +import numpy as np +import tvm +import tvm.testing +from tvm import relax as rx +from tvm import tir +from tvm.relax.testing import dump_ast +from tvm.relax.testing.ast_printer import ASTPrinter +from tvm.script import relax as R +from tvm.script import tir as T + +# Overload dump_ast to test both struct info and type annotations +dump_ast = partial(dump_ast, include_struct_info_annotations=True, include_type_annotations=True) + + +def strip_whitespace(text: str) -> str: + """ + Remove all whitespace to avoid reasoning about newlines and indents + """ + return re.sub(r"\s", "", text) + + +def normalize(func: rx.Function) -> rx.Function: + """ + Normalize the expr to fill in the checked_type_ and struct_info fields everywhere + """ + # using a default mutator to use the BlockBuilder's normalizer, + # which oddly differs from the Normalize pass + @rx.expr_functor.mutator + class DefaultMutator(rx.PyExprMutator): + pass + + mod = tvm.IRModule() + mod["main"] = func + mut = DefaultMutator(mod) + mod["main"] = mut.visit_expr(func) + return mod["main"] + + +def assert_fields(nodename: str, fields: Dict[str, str], target: str) -> None: + """ + Given a target string, ensure that the string defines the specified node + and that the given mappings of fields to values are present in the string. + Strips all whitespace in the target and fields. + Does not assume any particular ordering for the fields. + """ + stripped_target = strip_whitespace(target) + assert stripped_target.startswith(f"{nodename}(") + for field, value in fields.items(): + assert f"{field}={strip_whitespace(value)}" in stripped_target + + +# test cases are mostly adapted from text_expr, only testing very basic properties + + +def test_var() -> None: + v0 = rx.Var("v0") + v0_str = dump_ast(v0) + assert v0_str == 'Var(name_hint="v0")' + + v1 = rx.Var("v1", R.Tensor([54, 96], "float32")) + v1_no_annos = dump_ast( + v1, include_struct_info_annotations=False, include_type_annotations=False + ) + assert v1_no_annos == 'Var(name_hint="v1")' + v1_annos = dump_ast(v1) + assert v1_annos != v1_no_annos + assert "PrimExpr" in v1_annos + assert "struct_info" in v1_annos + assert "checked_type_" in v1_annos + + +def test_dataflow_var() -> None: + v0 = rx.DataflowVar("v0") + v0_str = dump_ast(v0) + assert v0_str == 'DataflowVar(name_hint="v0")' + + v1 = rx.DataflowVar("v1", R.Tensor([54, 96], "float16")) + v1_no_annos = dump_ast( + v1, include_struct_info_annotations=False, include_type_annotations=False + ) + assert v1_no_annos == 'DataflowVar(name_hint="v1")' + v1_annos = dump_ast(v1) + assert v1_annos != v1_no_annos + assert "PrimExpr" in v1_annos + assert "struct_info" in v1_annos + assert "checked_type_" in v1_annos + + +def test_match_cast() -> None: + # match_cast([16, 8], [m, n]) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + var = rx.Var("v0", R.Shape()) + b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) + b0_str = dump_ast(b0) + assert b0_str.startswith("MatchCast(") + assert "Constant" in b0_str + assert "PrimExpr(value=`m" in b0_str + assert "PrimExpr(value=`n" in b0_str + assert "16" in b0_str + assert "8" in b0_str + assert b0_str != dump_ast(b0, include_type_annotations=False) + + # var1: Tensor((m, n), "float32") = + # match_cast(var0: R.Tensor("float32"), [m, n]) + value = rx.Var("value", R.Tensor("float32")) + var = rx.Var("v1", R.Tensor([m, n], "float32")) + b1 = rx.MatchCast(var, value, 
R.Tensor([m, n], "float32")) + b1_str = dump_ast(b1) + assert b1_str.startswith("MatchCast(") + assert "PrimExpr(value=`m" in b1_str + assert "PrimExpr(value=`n" in b1_str + assert b1_str != dump_ast( + b1, include_type_annotations=False, include_struct_info_annotations=False + ) + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + b0_str = dump_ast(b0, include_type_annotations=False, include_struct_info_annotations=False) + assert b0_str.startswith("VarBinding(") + assert 'var=Var(name_hint="v0")' in b0_str + assert "value=" in b0_str + assert "Constant(" in b0_str + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + block0_str = dump_ast(block0) + assert block0_str.startswith("BindingBlock(") + assert "bindings=" in block0_str + assert "VarBinding(" in block0_str + assert "MatchCast(" in block0_str + assert '"v0"' in block0_str + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + block0_str = dump_ast(block0) + assert block0_str.startswith("DataflowBlock(") + assert "bindings=" in block0_str + assert "VarBinding(" in block0_str + assert "MatchCast(" in block0_str + assert '"v0"' in block0_str + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + seqe_str = dump_ast(seqe) + assert seqe_str.startswith("SeqExpr(") + assert "blocks=" in seqe_str + assert "BindingBlock(" in seqe_str + assert "VarBinding(" in seqe_str + assert "Constant(" in seqe_str + assert 'var=Var(name_hint="foo")' in seqe_str + assert "value=Constant(data" in strip_whitespace(seqe_str) + assert "body=" in seqe_str + + +def test_shape_expr() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + s = rx.ShapeExpr([m, n]) + s_str = dump_ast(s) + assert s_str.startswith("ShapeExpr(") + assert "values=" in s_str + assert "PrimExpr(value=`m: int32`)" in s_str + assert "PrimExpr(value=`n: int32`)" in s_str + + +def test_func(): + x = rx.Var("foo", R.Tensor("float32", ndim=2)) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + func = rx.Function([x], seqe, R.Tensor("float32")) + func = func.with_attr("global_symbol", "func") + + func_str = dump_ast(func) + assert func_str.startswith("Function(") + assert "params=" in func_str + assert "body=" in func_str + assert "ret_struct_info=" in func_str + assert "attrs=" in func_str + assert '"global_symbol": "func"' in func_str + assert "SeqExpr(" in func_str + assert "blocks=" in func_str + assert "VarBinding(" in func_str + assert func_str != dump_ast(func, include_type_annotations=False) + + +def test_shape_of(): + v0 = rx.Var("v0", R.Tensor(ndim=2)) + s0 = rx.get_shape_of(v0) + s0_str = dump_ast(s0) + assert s0_str.startswith("Call(") + assert 'op=Op(name="relax.shape_of")' in s0_str + assert "args=" in s0_str + assert 
'name_hint="v0"' in s0_str + + v1 = rx.Var("v1", R.Tensor([96, 54])) + s1 = rx.get_shape_of(v1) + s1_str = dump_ast(s1) + assert s1_str.startswith("ShapeExpr("), s1_str + assert "values=" in s1_str + assert "PrimExpr(value=`T.int64(96)`)" in s1_str + assert "PrimExpr(value=`T.int64(54)`)" in s1_str + + +def test_shape_expr(): + shape_expr = rx.ShapeExpr([10, 20]) + shape_expr_str = dump_ast(shape_expr) + assert shape_expr_str.startswith("ShapeExpr(") + assert "values" in shape_expr_str + assert "PrimExpr(value=`T.int64(10)`)" in shape_expr_str + assert "PrimExpr(value=`T.int64(20)`)" in shape_expr_str + + +def test_types(): + printer = ASTPrinter() + assert strip_whitespace(printer.visit_type_(rx.ShapeType())) == "ShapeType(ndim=-1)" + assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=1))) == "ShapeType(ndim=1)" + object_type = rx.ObjectType() + assert strip_whitespace(printer.visit_type_(object_type)) == "ObjectType()" + packed_type = rx.PackedFuncType() + assert strip_whitespace(printer.visit_type_(packed_type)) == "PackedFuncType()" + tensor_type = rx.DynTensorType(ndim=2, dtype="int32") + assert strip_whitespace(printer.visit_type_(tensor_type)) == "DynTensorType(ndim=2,dtype=int32)" + unit_type = rx.TupleType([]) + assert strip_whitespace(printer.visit_type_(unit_type)) == "TupleType(fields=[])" + tuple_type = rx.TupleType([rx.ShapeType(), object_type]) + assert_fields( + "TupleType", + {"fields": "[ShapeType(ndim=-1),ObjectType()]"}, + strip_whitespace(printer.visit_type_(tuple_type)), + ) + + func_type = rx.FuncType([tensor_type], unit_type) + assert_fields( + "FuncType", + {"arg_types": "[DynTensorType(ndim=2, dtype=int32)]", "ret_type": "TupleType(fields=[])"}, + printer.visit_type_(func_type), + ) + + +def test_struct_info(): + printer = ASTPrinter(include_type_annotations=True) + + assert printer.visit_struct_info_(rx.ObjectStructInfo()) == "ObjectStructInfo()" + + assert printer.visit_struct_info_(rx.PrimStructInfo("int32")) == "PrimStructInfo(dtype=int32)" + + # empty shape + empty_ssi = rx.ShapeStructInfo() + assert printer.visit_struct_info_(empty_ssi) == "ShapeStructInfo(ndim=-1)" + + # include some dimensions + shape_info = rx.ShapeStructInfo([tir.IntImm("int64", 1), tir.IntImm("int64", 2)]) + assert strip_whitespace(printer.visit_struct_info_(shape_info)) == strip_whitespace( + """ + ShapeStructInfo( + ndim=2, + values=[ + PrimExpr(value=`T.int64(1)`), + PrimExpr(value=`T.int64(2)`) + ] + ) + """ + ) + + # tensor struct info + default_tsi = rx.TensorStructInfo() + assert ( + strip_whitespace(printer.visit_struct_info_(default_tsi)) + == "TensorStructInfo(dtype=float32,ndim=-1)" + ) + + # use a var as the shape + x = rx.Var("x", struct_info=rx.ShapeStructInfo(values=[])) + var_tsi = rx.TensorStructInfo(shape=x, dtype="int32") + assert strip_whitespace(printer.visit_struct_info_(var_tsi)) == strip_whitespace( + """ + TensorStructInfo( + dtype=int32, + shape=Var( + name_hint="x", + struct_info=ShapeStructInfo(ndim=0, values=[]), + checked_type_=ShapeType(ndim=0) + ) + ) + """ + ) + + empty_tuple = rx.TupleStructInfo([]) + assert printer.visit_struct_info_(empty_tuple) == "TupleStructInfo(fields=[])" + + tuple_of_shape = rx.TupleStructInfo([empty_ssi]) + assert strip_whitespace(printer.visit_struct_info_(tuple_of_shape)) == strip_whitespace( + """ + TupleStructInfo(fields=[ + ShapeStructInfo(ndim=-1) + ]) + """ + ) + + simple_func = rx.FuncStructInfo([], rx.ObjectStructInfo()) + assert ( + strip_whitespace(printer.visit_struct_info_(simple_func)) + == 
"FuncStructInfo(params=[],ret=ObjectStructInfo())" + ) + + +def test_call_packed(): + # test case from test_parser + @R.function + def f( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m",), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.int64() + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.shape_of(t) + o: R.Object = R.call_packed( + "contrib.tensor_array_stack", x, y, sinfo_args=R.Object(), test_attr=True + ) + return o + + # checking that the call_packed call is turned into a call to an extern func + f_str = strip_whitespace( + dump_ast( + f, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=True, + ) + ) + + # the function has an annotated return type + assert "ret_struct_info=ObjectStructInfo()" in f_str + + assert isinstance(f.body, rx.SeqExpr) + extern_call = f.body.blocks[0].bindings[-1].value + extern_call_text = dump_ast( + extern_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=True, + ) + assert strip_whitespace(extern_call_text) in f_str + assert_fields( + "Call", + { + "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', + "args": '[Var(name_hint="x"), Var(name_hint="y")]', + "sinfo_args": "[ObjectStructInfo()]", + "attrs": '{"test_attr": 1}', + }, + extern_call_text, + ) + + # check that the op call is there too + op_call = f.body.blocks[0].bindings[0].value + op_call_text = dump_ast( + op_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=True, + ) + assert strip_whitespace(op_call_text) in f_str + assert_fields( + "Call", + { + "op": 'Op(name="relax.multiply")', + "args": '[Var(name_hint="x"), Var(name_hint="y")]', + }, + op_call_text, + ) + + # TODO: add testcase for op attrs + + +def test_call_tir(): + # also from test_parser + @tvm.script.ir_module + class TestCallTIR: + @T.prim_func + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + gv0 = R.call_tir(TestCallTIR.addone, (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIR + foo = mod["foo"] + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + ) + assert foo_str.startswith('Function(params=[Var(name_hint="x")]') + + # call_tir is an op in Relax and it takes an extern func as an argument + assert isinstance(foo.body, rx.SeqExpr) + tir_call = foo.body.blocks[0].bindings[0].value + tir_call_text = dump_ast( + tir_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + assert_fields( + "Call", + { + "op": 'Op(name="relax.call_tir")', + "args": """[ + GlobalVar(name_hint="addone"), + Tuple(fields=[Var(name_hint="x")]) + ]""", + "sinfo_args": """[ + TensorStructInfo( + dtype=float32, + shape=ShapeExpr( + values=[ + PrimExpr(value=`m`), + PrimExpr(value=`n`) + ] + ) + ) + ]""", + }, + tir_call_text, + ) + assert strip_whitespace(tir_call_text) in foo_str + + +def test_call_dps_packed(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + 
m, n = T.int64(), T.int64() + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + ) + assert foo_str.startswith('Function(params=[Var(name_hint="x")]') + + # call_dps_packed is an op in Relax and it takes an extern func as an argument + assert isinstance(foo.body, rx.SeqExpr) + tir_call = foo.body.blocks[0].bindings[0].value + tir_call_text = dump_ast( + tir_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + assert_fields( + "Call", + { + "op": 'Op(name="relax.call_dps_packed")', + "args": """[ + ExternFunc(global_symbol="test.op.identity"), + Tuple(fields=[Var(name_hint="x")]) + ]""", + "sinfo_args": """[ + TensorStructInfo( + dtype=float32, + shape=ShapeExpr( + values=[ + PrimExpr(value=`m`), + PrimExpr(value=`n`) + ] + ) + ) + ]""", + }, + tir_call_text, + ) + assert strip_whitespace(tir_call_text) in foo_str + + +def test_operators(): + @R.function + def foo(x: R.Tensor): + return R.unique(x, sorted=True, axis=-1) + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_struct_info_annotations=False, + ) + ) + assert 'Op(name="relax.unique")' in foo_str + # the sorted argument is true, so it will be a PrimValue of 1 + assert "PrimExpr(value=`T.int64(1)`)" in foo_str + # axis is -1 + assert "PrimExpr(value=`T.int64(-1)`)" in foo_str + + @R.function + def bar(x: R.Tensor): + return R.print(x, format="{}") + + bar_str = strip_whitespace( + dump_ast( + bar, + include_type_annotations=False, + include_struct_info_annotations=False, + ) + ) + # the format string is a StringImm argument + assert 'StringImm(value="{}")' in bar_str + + +def test_print_struct_info_annotation_non_var(): + @R.function + def f() -> R.Tensor: + return R.const([1, 2]) + + body = normalize(f).body + body_str = strip_whitespace(dump_ast(body)) + # the constant has a shape of (2,) + struct_info = strip_whitespace( + """ + struct_info=TensorStructInfo( + dtype=int32, + shape=ShapeExpr( + values=[PrimExpr(value=`T.int64(2)`)], + struct_info=ShapeStructInfo( + ndim=1, + values=[PrimExpr(value=`T.int64(2)`)] + ), + checked_type_=ShapeType(ndim=1) + ) + ) + """ + ) + assert struct_info in body_str + + +def test_print_type_annotation_non_var(): + @R.function + def f() -> R.Shape: + return R.shape_of(R.const(1)) + + body = normalize(f).body + assert isinstance(body, rx.SeqExpr) + call = body.blocks[-1].bindings[-1].value + assert isinstance(call, rx.Call) + arg = call.args[0] + arg_str = strip_whitespace(dump_ast(arg)) + # the constant should have a tensor type + assert "checked_type_=DynTensorType(ndim=0" in arg_str + + call_str = strip_whitespace(dump_ast(call)) + # we expect the shape_of call to have a checked_type_ of ShapeType + type_str = "checked_type_=ShapeType(ndim=0)" + assert type_str in call_str + + +def test_if(): + @R.function + def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"): + if cond: + x = R.const(1) + else: + x = R.const(2) + return x + + body = normalize(f).body + assert isinstance(body, rx.SeqExpr) + body_str = strip_whitespace(dump_ast(body)) + # we expect both branches to be seq exprs + assert "If" in body_str + assert "true_branch=SeqExpr(" in body_str + assert "false_branch=SeqExpr(" in body_str + + +def test_tuple_get_item(): + @R.function + def f(x: R.Tuple(R.Tensor((), 
dtype="int32"))) -> R.Tensor((), dtype="int32"): + return x[0] + + body = normalize(f).body + assert isinstance(body, rx.SeqExpr) + body_str = strip_whitespace(dump_ast(body)) + + assert "TupleGetItem" in body_str + assert 'tuple_value=Var(name_hint="x"' in body_str + assert "index=0" in body_str + + +def test_prim_value(): + prim_value = rx.PrimValue(tir.IntImm("int64", 1)) + prim_str = strip_whitespace(dump_ast(prim_value)) + assert prim_str == strip_whitespace( + """ + PrimValue( + value=PrimExpr(value=`T.int64(1)`), + struct_info=PrimStructInfo(dtype=int64), + checked_type_=PrimType(dtype=int64) + ) + """ + ) + + +def test_string_imm(): + string_imm = rx.StringImm("test") + str_str = strip_whitespace(dump_ast(string_imm)) + assert str_str == strip_whitespace( + """ + StringImm( + value="test", + struct_info=ObjectStructInfo(), + checked_type_=ObjectType() + ) + """ + ) + + +def test_datatype_imm(): + data_type_imm = rx.DataTypeImm("int32") + data_type_str = strip_whitespace(dump_ast(data_type_imm)) + assert data_type_str == strip_whitespace( + """ + DataTypeImm( + value=int32, + struct_info=ObjectStructInfo(), + checked_type_=ObjectType() + ) + """ + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py new file mode 100644 index 000000000000..4b194f154238 --- /dev/null +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm.script +import tvm.testing +from tvm import relax +from tvm.ir import assert_structural_equal +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_const_shape_arg(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Shape([1, 2]), y: R.Shape): + return x + + @T.prim_func + def extra_func(H: T.Buffer(T.int64(4), "int64")): + """Extra function, checks if the pass preserves it.""" + H[T.int64(1)] = H[T.int64(0)] + T.int64(1) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Shape([1, 2]), y: R.Shape): + shape_heap = R.null_value() + _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_IMM, + 2, + "", + sinfo_args=[R.Tuple()], + ) + return x + + @T.prim_func + def extra_func(H: T.Buffer(T.int64(4), "int64")): + H[T.int64(1)] = H[T.int64(0)] + T.int64(1) + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_static_fn_check(): + """Check static shape and function.""" + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): + return y + + @tvm.script.ir_module + class Expected: + @R.function + def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): + shape_heap = R.null_value() + _ = R.call_packed("vm.builtin.check_func_info", f, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_IMM, + 2, + "", + sinfo_args=[R.Tuple()], + ) + return y + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_simple_symbolic_shape(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor(["n", 2, "m"], "float32")): + return x + + sindex = { + "n": 0, + "m": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(["n", 2, "m"], "float32")): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 3, + MS.STORE_TO_HEAP, + sindex["n"], + MS.ASSERT_EQUAL_TO_IMM, + 2, + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + return x + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_symbolic_compute(): + MS = MatchShapeCode + MK = MakeShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) + ) -> R.Shape(ndim=3): + m = T.int64() + k = T.int64() + z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) + return R.shape([k + 1, m, 2]) + + 
# slot assignment: + # 0: n, 1: m, 2:k, 3: k+1 + sindex = {"n": 0, "m": 1, "k": 2, "k+1": 3} + + @tvm.script.ir_module + class Expected: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + # generated compute function + T.func_attr({"tir.is_host_func": 1}) + H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1) + + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) + ) -> R.Shape(ndim=3): + m = T.int64() + k = T.int64() + cls = Expected + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(4)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 3, + MS.STORE_TO_HEAP, + sindex["k"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + MS.NO_OP, + 0, + "", + sinfo_args=[R.Tuple()], + ) + _ = cls.shape_func(shape_heap) + # extra assertion on y's shape after shape computation + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 3, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["k"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["k+1"], + "", + sinfo_args=[R.Tuple()], + ) + z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) + # construct shape value for return + s = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 3, + MK.LOAD_SHAPE, + sindex["k+1"], + MK.LOAD_SHAPE, + sindex["m"], + MK.USE_IMM, + 2, + sinfo_args=[R.Shape(ndim=3)], + ) + return s + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_tuple_handling(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tuple( + R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) + ) + ): + return x + + # slot assignment: + sindex = {"n": 0, "m": 1, "k": 2} + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tuple( + R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) + ) + ): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(3)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + # recursively unpack tuple for static info check + _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", sinfo_args=[R.Tuple()]) + t0 = x[0] + _ = R.call_packed( + "vm.builtin.check_tensor_info", + t0, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + t1 = x[1] + _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", sinfo_args=[R.Tuple()]) + t1x0 = t1[0] + _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", sinfo_args=[R.Tuple()]) + t1x1 = t1[1] + _ = R.call_packed( + "vm.builtin.check_tensor_info", + t1x1, + 2, + R.dtype("int32"), + "", + sinfo_args=[R.Tuple()], + ) + # match shape checks. 
+ _ = R.call_packed( + "vm.builtin.match_shape", + t0, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + t1x1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["k"], + "", + sinfo_args=[R.Tuple()], + ) + return x + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_return_match_check(): + """Test when return body is not same as ret_struct_info, runtime match check needed.""" + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + return y + + # slot assignment: + sindex = { + "n": 0, + "m": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", sinfo_args=[R.Tuple()]) + # emit runtime function call since y do not have the right type. + y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, sinfo_args=[R.Object]) + # run check + _ = R.call_packed( + "vm.builtin.check_tensor_info", + y1, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + # shape check + _ = R.call_packed( + "vm.builtin.match_shape", + y1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + + return y + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py new file mode 100644 index 000000000000..d0d3344eb61e --- /dev/null +++ b/tests/python/relax/test_binding_rewrite.py @@ -0,0 +1,336 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
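The VMShapeLower tests above compare the pass output against hand-written expected modules; the pass itself can be run standalone in the same way. A minimal sketch (not part of the patch; emit_err_ctx=False matches the tests and only suppresses error-context strings):

import tvm
import tvm.script
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class SymbolicInput:
    @R.function
    def main(x: R.Tensor(["n", 2], "float32")):
        return x


# After the pass, shape handling is explicit: check_tensor_info validates the
# argument and match_shape populates/asserts slots in the allocated shape heap.
lowered = relax.transform.VMShapeLower(emit_err_ctx=False)(SymbolicInput)
print(lowered.script())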
+ +import pytest +import tvm +import tvm.testing +from tvm._ffi.base import TVMError +from tvm.relax.analysis import name_to_binding +from tvm.relax.binding_rewrite import DataflowBlockRewrite +from tvm.relax.expr import DataflowVar, Var +from tvm.script import relax as R + + +@tvm.script.ir_module +class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + +def assert_immutability(rwt, original_dfb, original_root_fn): + assert rwt.mutated_dfb() != original_dfb + assert rwt.mutated_root_fn() != original_root_fn + assert rwt.mutated_root_fn().body.blocks[0] != original_dfb + assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb() + + +def test_null_construct(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + DataflowBlockRewrite(dfb, root_fn) + + +def test_simple_add(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True) + + assert_immutability(rwt, dfb, root_fn) + + # check "tmp" added + assert "tmp" in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_auto_add_var(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=False) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var) + + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_auto_add_dfvar(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=True) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar) + + # immutatbility + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_remove_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(IdentityUnused["main"]) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(n2binding["unused"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + # check "unused" removed + assert "unused" not in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_remove_unused_undef(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + with pytest.raises(TVMError): + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever")) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever"), allow_undef=True) + + assert root_fn == rwt.mutated_root_fn() + + +def test_simple_rm_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = lv0 + unused1 = lv0 + R.output(lv0) + return 
lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +@tvm.script.ir_module +class DeadDFBlock: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + with R.dataflow(): + lv0 = x + R.output(lv0) + return x + + +def test_empty_dfb_after_removal(): + root_fn = DeadDFBlock["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_empty_dfb_after_all_removal(): + dfb = DeadDFBlock["main"].body.blocks[0] + root_fn = DeadDFBlock["main"] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_chained_rm_all_unused(): + @tvm.script.ir_module + class IdentityChainedUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv0) + return lv0 + + root_fn = IdentityChainedUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_replace_all_uses(): + @tvm.script.ir_module + class Lv0To1: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + # lv0 => lv1 + # / \ + # lv2 lv3 + # \ / + # lv4 + with R.dataflow(): + lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed( + "my_relu", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed( + "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed( + "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed( + "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed( + "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv4) + return lv4 + + root_fn = Lv0To1["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(root_fn) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var) + rwt.remove_unused(n2binding["lv0"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + n2binding_after = name_to_binding(rwt.mutated_root_fn()) + assert "lv0" not 
in n2binding_after + + +def test_simple_module_update(): + @tvm.script.ir_module + class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True) + + new_ir = rwt.mutate_irmodule(Identity) + + # immutatbility + assert new_ir != Identity + assert 2 == len(new_ir["main"].body.blocks[0].bindings) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(new_ir, GroundTruth) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py new file mode 100644 index 000000000000..9d9d28d7d615 --- /dev/null +++ b/tests/python/relax/test_blockbuilder.py @@ -0,0 +1,582 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
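The DataflowBlockRewrite tests above always follow the same construct/mutate/materialize flow. A condensed sketch of that flow (not part of the patch; it reuses the exact calls exercised in test_simple_add and test_simple_module_update):

import tvm
import tvm.script
from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.script import relax as R


@tvm.script.ir_module
class Mod:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            lv0 = x
            R.output(lv0)
        return lv0


fn = Mod["main"]
rwt = DataflowBlockRewrite(fn.body.blocks[0], fn)      # bind rewriter to block + function
rwt.add(name="tmp", expr=fn.params[0], is_dfvar=True)  # append a new dataflow binding
new_mod = rwt.mutate_irmodule(Mod)                     # Mod itself stays unchanged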
+ +import pytest +import tvm +import tvm.testing + +from tvm import te, tir, topi +from tvm import relax as rx, relay +from tvm.ir.base import assert_structural_equal +from tvm.relax import ExternFunc +from tvm.script import relax as R +from tvm.tir.function import PrimFunc + + +@tvm.register_func("test.blockbuilder.nop") +def nop(): + pass + + +def test_block_builder(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + bb._begin_binding_block() + gv0 = bb.emit(rx.op.add(x, y)) + bb._begin_dataflow_block() + lv0 = bb.emit(rx.op.multiply(gv0, y)) + gv1 = bb.emit_output(rx.op.multiply(lv0, lv0)) + b0 = bb._end_block() + bb._begin_dataflow_block() + lv1 = bb.emit(rx.op.multiply(gv0, y)) + gv2 = bb.emit_output(rx.op.multiply(lv1, lv1)) + b1 = bb._end_block() + gv3 = bb.emit(rx.op.add(x, y)) + b2 = bb._end_block() + + assert isinstance(b0, rx.DataflowBlock) + assert isinstance(b1, rx.DataflowBlock) + assert not isinstance(b2, rx.DataflowBlock) + + +def test_emit_with_name(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + bb._begin_dataflow_block() + lv0 = bb.emit(rx.op.add(x, y), "add") + gv0 = bb.emit_output(rx.op.multiply(lv0, y), "multi") + b0 = bb._end_block() + + assert b0.bindings[0].var.name_hint == "add" + assert b0.bindings[1].var.name_hint == "multi" + + +def test_function_single_block(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + lv1 = bb.emit(rx.op.multiply(lv0, y)) + assert lv1.name_hint == "lv1" + gv0 = bb.emit_output(lv1) + assert gv0.name_hint == "gv" + bb.emit_func_output(gv0) + + func = bb.get()["func"] + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv0 + assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 3 + + +def test_function_multi_blocks(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + assert gv0.name_hint == "gv" + gv1 = bb.emit(rx.op.add(gv0, gv0)) + assert gv1.name_hint == "gv1" + with bb.dataflow(): + lv1 = bb.emit(rx.op.add(gv1, gv1)) + assert lv1.name_hint == "lv1" + gv2 = bb.emit_output(gv1) + bb.emit_func_output(gv2) + + func = bb.get()["func"] + + assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv2 + assert len(func.body.blocks) == 3 + assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[1].bindings) == 1 + assert len(func.body.blocks[2].bindings) == 2 + + +def test_multi_functions(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", 
rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func1", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + with bb.function("func2", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(y, x)) + # TODO(@yuchen): enable block builder to reset local var unique name map + assert lv0.name_hint == "lv1" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + mod = bb.get() + func1 = mod["func1"] + assert func1.params[0] == x + assert func1.params[1] == y + assert len(func1.body.blocks) == 1 + func2 = mod["func2"] + assert func2.params[0] == x + assert func2.params[1] == y + assert len(func2.body.blocks) == 1 + + +def test_binary_shape_type_deduction(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + z = rx.Var("z", rx.TensorStructInfo([5], "float16")) + w = rx.Var("w", rx.TensorStructInfo([k], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y, z, w]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + + lv1 = bb.emit(rx.op.multiply(x, z)) + assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16")) + + lv2 = bb.emit(rx.op.multiply(z, w)) + assert isinstance(lv2.struct_info, rx.TensorStructInfo) + assert lv2.struct_info.ndim == 1 + assert lv2.struct_info.dtype == "float16" + + lv3 = bb.emit(rx.op.multiply(y, w)) + assert isinstance(lv3.struct_info, rx.TensorStructInfo) + assert lv3.struct_info.ndim == 1 + assert lv3.struct_info.dtype == "float16" + + gv0 = bb.emit_output(lv3) + bb.emit_func_output(gv0) + + assert isinstance(gv0.checked_type, rx.DynTensorType) + assert gv0.checked_type.ndim == 1 + assert gv0.checked_type.dtype == "float16" + + +def test_emit_match_cast(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1)) + y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8])) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + # lv0: Tensor((m, n), "float32") = + # match_cast(x: Tensor(_, "float32"], [m, n)) + lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32")) + assert isinstance(lv0, rx.DataflowVar) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) + + # lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n])) + lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n])) + assert lv1.struct_info == rx.ShapeStructInfo([m, n]) + gv0 = bb.emit_output(lv1) + + bb.emit_func_output(gv0) + func = bb.get()["func"] + block = func.body.blocks[0] + b0, b1 = block.bindings[:2] + assert isinstance(b0, rx.MatchCast) + assert isinstance(b1, rx.MatchCast) + + assert b0.value == x + assert b0.struct_info == rx.TensorStructInfo([m, n], "float32") + assert b0.var == lv0 + + assert b1.value == y + assert b1.struct_info == rx.ShapeStructInfo([m, n]) + assert b1.var == lv1 + + +def test_emit_match_cast_binding_in_dataflow_block(): + bb = rx.BlockBuilder() + + x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1)) + m = tir.Var("m", dtype="int64") + gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1)) + match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32")) + + with bb.function("main", [x]): + 
with bb.dataflow(): + bb.emit_normalized(match_cast) + bb.emit_output(gv) + bb.emit_func_output(x) + + func = bb.get()["main"] + block = func.body.blocks[0] + b0 = block.bindings[0] + assert isinstance(b0, rx.MatchCast) + + assert b0.value == x + assert isinstance(b0.struct_info, rx.TensorStructInfo) + assert b0.struct_info.shape[0] == m + assert b0.var == gv + + +def test_normalize(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + # Call node + add_call = rx.op.multiply(x, y) + + bb.normalize(add_call) + shape = rx.get_shape_of(add_call) + + assert isinstance(shape, rx.ShapeExpr) + assert shape[0] == m + assert shape[1] == n + + # Tuple node + tuple_1 = rx.Tuple([x, y]) + bb.normalize(tuple_1) + assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) + + # Nested Tuple + tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) + bb.normalize(tuple_2) + type_anno0 = x.checked_type + type_anno1 = y.checked_type + assert_structural_equal( + tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) + ) + assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) + + +def test_call_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + with bb.dataflow(): + out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello")) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + +def test_call_te_with_unsupported_shape_arg(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo((200,), "float32")) + s = rx.Var("s", rx.ShapeStructInfo((200,))) + + with pytest.raises(AssertionError): + with bb.function("rx_func", [x]): + out = bb.emit(bb.call_te(topi.reshape, x, s)) + bb.emit_func_output(out) + + +def test_emit_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello") + 
bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + def get_tir_func(): + A = te.placeholder((n, m), dtype="float32", name="A") + B = te.placeholder((n, m), dtype="float32", name="B") + C = te.placeholder((n, m), dtype="float32", name="C") + out = te_func((A, B), {"C": C}, "") + return tvm.te.create_prim_func([A, B, C, out], index_dtype_override="int64") + + # check TIR structure matches expected + assert_structural_equal(mod["te_func"].body, get_tir_func().body) + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "te_func" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.args[1][2] == z + + +def test_emit_te_multiple(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([128, m], "float32")) + + def te_func(A): + B = te.compute((128, 128), lambda i, j: A[i, j] + 1) + return B + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x) + y1 = bb.emit_te(te_func, y) + z1 = bb.emit_te(te_func, z) + bb.emit_func_output(z1) + + mod = bb.get() + rx_func = mod["rx_func"] + + prim_func = [] + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + prim_func.append(mod[gv]) + + # only two PrimFuncs were generated since two of them are equal so got deduped + assert len(prim_func) == 2 + assert rx_func.body.blocks[0].bindings[0].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[1].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[2].value.args[0].name_hint == "te_func1" + + +def test_emit_te_multiple_output(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + + def te_func(A): + B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B") + return (B0, B1) + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + z = rx.TupleGetItem(y, 0) + bb.emit_func_output([y, z]) + + rx_func = bb.get()["rx_func"] + + # check call tir output shape is a Tuple of ShapeExpr + assert rx_func.params[0] == x + call_node = rx_func.body.blocks[0].bindings[0].value + assert call_node.op == relay.op.get("relax.call_tir") + assert call_node.args[0].name_hint == "te_func" + assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo) + assert len(call_node.sinfo_args[0].fields) == 2 + assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr) + assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr) + + +def test_emit_te_extern(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = 
mod["rx_cblas_matmul"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "matmul" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.sinfo_args[0].shape[0] == n + assert call_node.sinfo_args[0].shape[1] == n + + +def test_emit_te_prim_value(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", R.Tensor([n, m], "float32")) + a_min = rx.PrimValue(0) + a_max = rx.PrimValue(6) + + with bb.function("rx_clip", [x]): + out = bb.emit_te(topi.clip, x, a_min, a_max) + bb.emit_func_output(out) + + rx_func = bb.get()["rx_clip"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[1][0] == x + + +def test_nested_function_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, x)) + with bb.function("func1", [x, y]): + gv1 = bb.emit(rx.op.add(x, x)) + bb.emit_func_output(gv0) + + +def test_emit_func_output_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + bb.emit_func_output(gv0) + + +def test_func_params_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0, [x]) + + +def test_no_func_params_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func"): + gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), [])) + bb.emit_func_output(gv0) + + +def test_block_builder_scope_recovery(): + bb = rx.BlockBuilder() + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with pytest.raises(RuntimeError): + # this line fails + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + + # current should be recovered + assert rx.BlockBuilder.current() is None + + # second attempt to do it correctly. 
+ with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py new file mode 100644 index 000000000000..c8ca44311de5 --- /dev/null +++ b/tests/python/relax/test_codegen_cutlass.py @@ -0,0 +1,685 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relax +from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul +from tvm.contrib.pickle_memoize import memoize +from tvm.relax.backend import get_patterns_with_prefix +from tvm.relax.backend.contrib.cutlass import partition_for_cutlass +from tvm.script import relax as R +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +@pytest.fixture(autouse=True) +def reset_seed(): + np.random.seed(0) + + +@tvm.script.ir_module +class Conv2dBiasReLU: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight: R.Tensor((32, 3, 3, 16), "float16"), + bias: R.Tensor((1, 1, 1, 32), "float16"), + ): + with R.dataflow(): + conv1 = R.nn.relu( + R.nn.conv2d(data, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI") + + bias, + ) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 8), "float16"), + weight1: R.Tensor((8, 3, 3, 8), "float16"), + weight2: R.Tensor((8, 3, 3, 8), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + conv2 = relax.op.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + R.output(conv2) + + return conv2 + + +has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) + +cutlass_enabled = pytest.mark.skipif( + not has_cutlass, + reason="CUTLASS not enabled.", +) + +pytestmark = [cutlass_enabled] + + +def build_and_run(mod, inputs_np, target, legalize=False): + if legalize: + mod = relax.transform.LegalizeOps()(mod) + + dev = tvm.device(target, 0) + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + +def get_result_with_relax_cutlass_offload(mod, *args, assert_all_bindings_fused=True): + patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")] + assert len(patterns) != 0, "Cannot find cutlass patterns" + + mod = partition_for_cutlass(mod) + + if assert_all_bindings_fused: + assert len(mod["main"].body.blocks[0].bindings) == 1 + + codegen_pass = 
relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}) + mod = codegen_pass(mod) + + return build_and_run(mod, args, "cuda") + + +def test_kernel_sharing(): + low, high = -1, 1 + data_np = np.random.randint(low, high, size=(16, 32, 32, 8)).astype("float16") + weight1_np = np.random.randint(low, high, size=(8, 3, 3, 8)).astype("float16") + weight2_np = np.random.randint(low, high, size=(8, 3, 3, 8)).astype("float16") + + out = get_result_with_relax_cutlass_offload( + Conv2dx2, data_np, weight1_np, weight2_np, assert_all_bindings_fused=False + ) + ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm", legalize=True) + + np.testing.assert_equal(out, ref) + + +def get_relax_conv2d_module( + data_shape, + weight_shape, + dtype, + with_bias=False, + activation=None, + residual_bin_op=None, + residual_activation=None, +): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(data_shape, dtype)) + weight = R.arg("weight", R.Tensor(weight_shape, dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((1, 1, 1, weight_shape[0]), dtype)) + + with R.dataflow() as frame: + output = R.emit( + R.nn.conv2d( + data, + weight, + out_dtype=dtype, + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + ) + if with_bias: + output = R.emit(output + bias) + if activation is not None: + output = R.emit(activation(output)) + if residual_bin_op is not None: + output = R.emit(residual_bin_op(output, data)) + if residual_activation is not None: + output = R.emit(residual_activation(output)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_relax_matmul_module( + x_shape, + y_shape, + dtype, + transposed_y=False, + with_bias=False, + activation=None, + residual_bin_op=None, + residual_activation=None, +): + if transposed_y: + n = y_shape[-2] + else: + n = y_shape[-1] + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor(x_shape, dtype)) + y = R.arg("y", R.Tensor(y_shape, dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((n,), dtype)) + + with R.dataflow() as frame: + if transposed_y: + axes = list(range(len(y_shape) - 2)) + [-1, -2] + y = R.emit(R.permute_dims(y, axes=axes)) + result = R.emit(R.matmul(x, y, out_dtype=dtype)) + if with_bias: + result = R.emit(result + bias) + if activation is not None: + result = R.emit(activation(result)) + if residual_bin_op is not None: + result = R.emit(residual_bin_op(result, x)) + if residual_activation is not None: + result = R.emit(residual_activation(result)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def _to_concrete_shape(symbolic_shape, var_table): + result = [] + for dim in symbolic_shape: + if not isinstance(dim, tvm.tir.expr.Var): + result.append(dim) + continue + + if dim not in var_table: + var_table[dim] = np.random.randint(10, 50) + result.append(var_table[dim]) + + return tuple(result) + + +_vars = { + "a": tvm.tir.expr.Var("a", "int64"), + "b": tvm.tir.expr.Var("b", "int64"), +} + + +_epilogue_table = { + "none": (False, None), + "bias": (True, None), + "relu": (True, R.nn.relu), + "gelu": (True, R.nn.gelu), + "silu": (True, R.nn.silu), +} + + +_residual_block_table = { + "none": (None, None), + "add_relu": (R.add, R.nn.relu), + "mul_relu": (R.multiply, R.nn.relu), + "add": (R.add, None), + "mul": (R.multiply, 
None), +} + + +@pytest.mark.parametrize( + "data_shape, weight_shape, dtype, epilogue, residual_block", + [ + # Regular + ((16, 32, 32, 16), (32, 3, 3, 16), "float16", "none", "none"), + ((40, 128, 50, 16), (16, 2, 2, 16), "float16", "bias", "none"), + ((3, 64, 64, 128), (32, 1, 1, 128), "float16", "relu", "none"), + ((12, 32, 32, 16), (45, 5, 5, 16), "float16", "silu", "none"), + # residual block + ((3, 64, 64, 16), (16, 3, 3, 16), "float16", "relu", "add"), + ((16, 32, 32, 16), (16, 3, 3, 16), "float16", "relu", "mul_relu"), + ((40, 128, 50, 16), (16, 3, 3, 16), "float16", "bias", "add_relu"), + ((128, 32, 32, 16), (16, 3, 3, 16), "float16", "silu", "mul"), + ], +) +def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_block): + low, high = -1, 1 + data = np.random.randint(low, high, size=data_shape).astype(dtype) + weight = np.random.randint(low, high, size=weight_shape).astype(dtype) + bias = np.random.randint(low, high, size=(1, 1, 1, weight_shape[0])).astype(dtype) + + with_bias, activation = _epilogue_table[epilogue] + residual_bin_op, residual_activation = _residual_block_table[residual_block] + + if with_bias: + args = (data, weight, bias) + else: + args = (data, weight) + + mod = get_relax_conv2d_module( + data_shape, + weight_shape, + dtype, + with_bias=with_bias, + activation=activation, + residual_bin_op=residual_bin_op, + residual_activation=residual_activation, + ) + out = get_result_with_relax_cutlass_offload(mod, *args) + + ref = build_and_run(mod, args, "llvm", legalize=True) + + tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + + +def test_cutlass_partition_conv2d_residual_blocked(): + @tvm.script.ir_module + class Conv2dReLU: + """ + This conv2d should not be fused as conv2d residual block, because both lhs and rhs of + the last R.add depends on the result of conv2d. 
+ """ + + @R.function + def main( + data: R.Tensor((32, 3, 3, 16), "float32"), + weight: R.Tensor((16, 3, 3, 16), "float32"), + bias: R.Tensor((1, 1, 1, 16), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d( + data, + weight, + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + out = R.nn.relu(conv1 + bias) + # residual depends on conv result, which cannot be handled in cutlass + result = out + out + R.output(result) + + return result + + mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False) + for f_var in mod.functions: + func = mod[f_var] + if func.attrs and "Composite" in func.attrs: + # verify that the function is not fused as residual block + assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu" + + +@pytest.mark.parametrize( + "x_shape, y_shape, transpose_y, epilogue, residual_block", + [ + # Regular + ((32, 6), (6, 16), False, "none", "none"), + ((_vars["a"], 6), (6, 16), False, "bias", "none"), + # Transposed + ((4, 16), (16, 128), True, "relu", "none"), + ((35, 8), (8, 8), True, "gelu", "none"), + # 3D x 3D + ((6, 32, 8), (6, 8, 10), False, "bias", "none"), + ((6, 32, 8), (6, 8, 10), True, "none", "none"), + ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu", "none"), + # 3D x 2D + ((6, 32, 8), (8, 10), False, "none", "none"), + ((_vars["a"], 32, 8), (8, 10), False, "bias", "none"), + ((10, 16, 8), (8, 10), True, "relu", "none"), + # 2D x 3D + ((32, 8), (10, 8, 10), False, "relu", "none"), + ((32, 8), (_vars["a"], 8, 10), True, "gelu", "none"), + # ND x 2D + ((3, 6, 32, 8), (8, 10), False, "bias", "none"), + ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none", "none"), + # 2D x ND + ((32, 8), (5, 3, 8, 10), False, "gelu", "none"), + # ND x ND + ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu", "none"), + ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu", "none"), + ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none", "none"), + # Residual + ((32, 8), (8, 8), False, "bias", "add"), + ((4, 16), (16, 16), True, "relu", "add_relu"), + # Residual fusion without bias - this is supported via the matmul + bias pattern + # where bias == residual input + ((4, 16), (16, 16), False, "none", "add"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + "float16", + ], +) +def test_matmul_offload( + x_shape, + y_shape, + transpose_y, + epilogue, + residual_block, + dtype, +): + with_bias, activation = _epilogue_table[epilogue] + var_table = {} + concrete_x_shape = _to_concrete_shape(x_shape, var_table) + concrete_y_shape = _to_concrete_shape(y_shape, var_table) + x = np.random.randn(*concrete_x_shape).astype(dtype) + y = np.random.randn(*concrete_y_shape).astype(dtype) + + if transpose_y: + y = np.swapaxes(y, -2, -1) + y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) + + if with_bias: + bias = np.random.randn(concrete_y_shape[-1]).astype(dtype) + args = (x, y, bias) + else: + bias = None + args = (x, y) + + residual_bin_op, residual_activation = _residual_block_table[residual_block] + + mod = get_relax_matmul_module( + x_shape, + y_shape, + dtype, + with_bias=with_bias, + transposed_y=transpose_y, + activation=activation, + residual_bin_op=residual_bin_op, + residual_activation=residual_activation, + ) + out = get_result_with_relax_cutlass_offload(mod, *args) + ref = build_and_run(mod, args, "llvm", legalize=True) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "x_shape, y_shape, expected", + [ + # Regular matmul + ((3, 4), (4, 5), True), + # Batch matmul without stretching + ((3, 16, 15), 
(3, 15, 2), True), + ((_vars["a"], 16, 15), (_vars["a"], 15, 2), True), + # Broadcast 2D to 3D + ((3, 16, 15), (15, 2), True), + ((_vars["a"], 16, 15), (15, 2), True), + ((16, 15), (3, 15, 2), True), + # Broadcast one-length dimension + ((1, 16, 15), (3, 15, 2), True), + ((3, 16, 15), (1, 15, 2), True), + ((1, 1, 16, 15), (3, 2, 4, 15, 2), True), + ((1, 1, 16, 15), (3, _vars["a"], 4, 15, 2), True), + # ND x ND + ((3, 2, 4, 16, 15), (3, 2, 4, 15, 2), True), + ((_vars["a"], 2, 4, 16, 15), (_vars["a"], 2, 4, 15, 2), True), + ( + (_vars["a"], _vars["b"], 4, 16, 15), + (_vars["a"], _vars["b"], 4, 15, 2), + True, + ), + # ND x ND with one-length dimension + ((1, 2, 4, 16, 15), (1, 2, 4, 15, 2), True), + ((3, 2, 1, 16, 15), (3, 2, 1, 15, 2), True), + # Extra one-length dimension doesn't block broadcasting + ((3, 2, 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), True), + # Not broadcasting all dims. Cannot be computed by stride-based batch gemm + ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False), + ((3, 2, 4, 16, 15), (2, 4, 15, 2), False), + # Different shape + ((3, 4, 16, 15), (3, 2, 15, 2), False), + ((3, _vars["a"], 16, 15), (3, _vars["b"], 15, 2), False), + # Cannot prove that broadcast dimensions are equal + ((_vars["a"], 16, 15), (3, 15, 2), False), + ((3, _vars["a"], 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), False), + # Reduction axis must be constant + ((3, _vars["a"]), (_vars["a"], 5), False), + ], +) +def test_is_shape_valid_for_cutlass_matmul(x_shape, y_shape, expected): + assert is_shape_valid_for_cutlass_matmul(x_shape, y_shape) == expected + + +@pytest.mark.parametrize( + "x_shape, y_shape, transpose_y, dtype", + [ + # Not broadcasting all dims. Cannot be computed by stride-based batch gemm + ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False, "float16"), + ((3, 2, _vars["a"], 16, 15), (3, 2, 4, 15, 2), False, "float16"), + ((1, 2, 1, 16, 15), (2, 1, 4, 15, 2), False, "float16"), + ((3, 2, 4, 16, 15), (2, 4, 15, 2), True, "float16"), + ((3, 16, 15), (2, 1, 3, 15, 2), True, "float16"), + ((3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"), + ((_vars["a"], 1, 3, 16, 15), (_vars["b"], 1, 3, 15, 2), True, "float16"), + ((_vars["a"], _vars["b"], 3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"), + ], +) +def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype): + if transpose_y: + y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) + + mod = get_relax_matmul_module( + x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y + ) + mod = partition_for_cutlass(mod) + + assert len(mod.functions) == 1 + + +def test_cutlass_partition_matmul_tuple_return_blocked(): + @tvm.script.ir_module + class TransposedMatmul: + @R.function + def main( + x: R.Tensor((4, 4), "float32"), + y: R.Tensor((4, 4), "float32"), + ): + with R.dataflow(): + lv1 = R.permute_dims(y) + # Because lv1 is used by both lv2 and out, it should stay out of + # the fused function. Otherwise the fused function will return + # tuple output, which isn't possible in cutlass, e.g. + # @R.function + # def fused_relax_permute_dims_relax_matmul(...): + # R.func_attr({"Composite": "cutlass.matmul_transposed", "Primitive": 1}) + # with R.dataflow(): + # gv: R.Tensor((4, 4), dtype="float32") = R.permute_dims(y, axes=None) + # gv1: R.Tensor((4, 4), dtype="float32") = R.matmul(x, gv, out_dtype="void") + # R.output(gv, gv1) + # return (gv, gv1) # Cannot get `gv` if dispatch to cutlass kernel. 
+ lv2 = R.matmul(x, lv1) + out = R.matmul(lv1, lv2) + R.output(out) + + return out + + mod = partition_for_cutlass(TransposedMatmul, annotate_codegen=False) + for f_var in mod.functions: + func = mod[f_var] + if func.attrs and "Composite" in func.attrs: + # verify that the function is not fused as transposed matmul + assert func.attrs["Composite"] == "cutlass.matmul" + + +def test_cutlass_partition_matmul_cyclic_dependency_blocked(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128), "float16")): + with R.dataflow(): + # Because lv1 depends on lv, this block should be fused as matmul instead of matmul_bias. + lv = R.matmul(x, w) + lv1 = R.power(lv, R.const(2.0, "float16")) + lv2 = R.add(lv, lv1) + R.output(lv2) + return lv2 + + mod = partition_for_cutlass(Module, annotate_codegen=False) + for f_var in mod.functions: + func = mod[f_var] + if func.attrs and "Composite" in func.attrs: + assert func.attrs["Composite"] == "cutlass.matmul" + + +@pytest.fixture(params=["float16", "float32"]) +def attention_dtype(request): + return request.param + + +@pytest.fixture( + params=[ + # B, S, N, H + (32, (8, 8), 16, (8, 8)), + (4, (16, 8), 32, (8, 8)), # s != s_kv + (4, (16, 8), 32, (8, 16)), # h != h_v + (32, (8, 8), 16, (4, 4)), # h is not aligned + (2, (8, 8), 8, (256, 256)), # needs output accumulator buffer + ] +) +def attention_size(request): + return request.param + + +def get_relax_attention_module(q, k, v, bias=None, qk_scale=None): + dtype = str(q.dtype) + + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import relax as relax_builder, tir as T + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + q = R.arg("q", R.Tensor(q.shape, dtype)) + k = R.arg("k", R.Tensor(k.shape, dtype)) + v = R.arg("v", R.Tensor(v.shape, dtype)) + if bias is not None: + bias = R.arg("bias", R.Tensor(bias.shape, dtype)) + with R.dataflow() as frame: + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +@memoize("topi.tests.test_codegen_cutlass.test_attention_offload") +def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, qk_scale, dtype): + q = np.random.randn(b, s, n, h).astype(dtype) + k = np.random.randn(b, s_kv, n, h).astype(dtype) + v = np.random.randn(b, s_kv, n, h_v).astype(dtype) + qt = q.transpose(0, 2, 1, 3) # b, n, s, h + kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv + if not qk_scale == "none": + score = qt @ kt * qk_scale # b, n, s, s_kv + else: + score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv + if not bias_shape == "none": + bias = np.random.randn(*bias_shape).astype(dtype) + score = score + bias.reshape(*bias_reshape) # b, n, s, s_kv + else: + bias = None + attn = tvm.topi.testing.softmax_python(score, -1) + vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v + ref = attn @ vt # b, n, s, h_v + return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v + + +def test_attention_offload(attention_size, attention_dtype): + b, (s, s_kv), n, (h, h_v) = attention_size + q, k, v, _, ref = get_numpy_attention_ref( + b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype + ) + + mod = get_relax_attention_module(q, k, v) + out = get_result_with_relax_cutlass_offload(mod, q, k, v) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, 
atol=1e-2) + + +@pytest.fixture( + params=[ + # B, S, N, H, bias_shape, bias_reshape + (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)), + (4, (16, 8), 32, (8, 16), (4, 16, 8), (4, 1, 16, 8)), + (4, (16, 8), 32, (8, 16), (4, 8), (4, 1, 1, 8)), + ] +) +def attention_bias_size(request): + return request.param + + +def test_attention_bias_offload(attention_bias_size, attention_dtype): + b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size + q, k, v, bias, ref = get_numpy_attention_ref( + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", attention_dtype + ) + + mod = get_relax_attention_module(q, k, v, bias) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +@pytest.fixture( + params=[ + # B, S, N, H, bias_shape, bias_reshape + (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)), + (4, (16, 8), 32, (8, 16), "none", "none"), + ] +) +def attention_scale_size(request): + return request.param + + +@pytest.fixture(params=[0.01, 1e-8, -0.5, 1.23]) +def attention_scale(request): + return request.param + + +def test_attention_scale_offload(attention_scale_size, attention_scale, attention_dtype): + b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size + q, k, v, bias, ref = get_numpy_attention_ref( + b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, attention_dtype + ) + + mod = get_relax_attention_module(q, k, v, bias, attention_scale) + if bias is None: + out = get_result_with_relax_cutlass_offload(mod, q, k, v) + else: + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py new file mode 100644 index 000000000000..66f442f16519 --- /dev/null +++ b/tests/python/relax/test_codegen_dnnl.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
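The CUTLASS tests above wrap partitioning, codegen, and execution in small helpers; end to end the offload path is short. A rough sketch (not part of the patch; the "sm": 80 and find_first_valid options are simply the values the tests use, and a CUDA device with CUTLASS enabled is assumed):

import tvm
from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass


def run_with_cutlass(mod, *inputs_np):
    """Partition offloadable ops, generate CUTLASS kernels, then run on CUDA."""
    mod = partition_for_cutlass(mod)
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})(mod)
    ex = relax.build(mod, "cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    args = [tvm.nd.array(x, dev) for x in inputs_np]
    return vm["main"](*args).numpy()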
+import pytest
+import numpy as np
+import tvm
+import tvm.testing
+
+from tvm import relax
+from tvm.script import relax as R
+from tvm.relax.dpl import make_fused_bias_activation_pattern
+from tvm.contrib.pickle_memoize import memoize
+
+
+@tvm.script.ir_module
+class Conv2dReLUx2:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        weight2: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1)))
+            conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0)))
+            R.output(conv2)
+
+        return conv2
+
+
+has_dnnl = tvm.get_global_func("relax.ext.dnnl", True)
+
+dnnl_enabled = pytest.mark.skipif(
+    not has_dnnl,
+    reason="DNNL not enabled.",
+)
+
+pytestmark = [dnnl_enabled]
+
+
+def build_and_run(mod, inputs, legalize=False):
+    if legalize:
+        mod = relax.transform.LegalizeOps()(mod)
+
+    target = tvm.target.Target("llvm")
+    dev = tvm.cpu()
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs]
+
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    return f(*inputs).numpy()
+
+
+def test_dnnl_offload():
+    pat = make_fused_bias_activation_pattern(
+        "relax.nn.conv2d", with_bias=False, activation="relax.nn.relu"
+    )
+
+    seq = tvm.transform.Sequential(
+        [
+            relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", pat)]),
+            relax.transform.MergeCompositeFunctions(),
+            relax.transform.RunCodegen(),
+        ]
+    )
+
+    @memoize("relax.tests.test_codegen_dnnl.conv2d_relu_x2")
+    def get_ref():
+        data_np = np.random.randn(1, 64, 56, 56).astype("float32")
+        weight1_np = np.random.randn(64, 64, 3, 3).astype("float32")
+        weight2_np = np.random.randn(64, 64, 3, 3).astype("float32")
+        inputs = [data_np, weight1_np, weight2_np]
+        ref = build_and_run(Conv2dReLUx2, inputs, legalize=True)
+        return inputs, ref
+
+    inputs, ref = get_ref()
+
+    out = build_and_run(seq(Conv2dReLUx2), inputs)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_dnnl_offload()
diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py
new file mode 100644
index 000000000000..595103bc5fb7
--- /dev/null
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
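+# Offload test for the TensorRT BYOC backend: conv2d, relu, and add patterns are
+# partitioned with FuseOpsByPattern, merged, compiled via RunCodegen, and the CUDA
+# result is checked against an LLVM reference built with LegalizeOps.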
+import pytest +import numpy as np +import tvm +import tvm.testing + +from tvm import relax +from tvm.script import relax as R +from tvm.relax.dpl import make_fused_bias_activation_pattern, is_op, wildcard +from tvm.contrib.pickle_memoize import memoize + + +@tvm.script.ir_module +class Conv2dResidualBlock: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1))) + conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(1, 1))) + out = relax.op.add(conv2, data) + R.output(out) + + return out + + +has_tensorrt = tvm.get_global_func("relax.ext.tensorrt", True) + +tensorrt_enabled = pytest.mark.skipif( + not has_tensorrt, + reason="TENSORRT not enabled.", +) + +pytestmark = [tensorrt_enabled] + + +def build_and_run(mod, inputs_np, target, legalize=False): + if legalize: + mod = relax.transform.LegalizeOps()(mod) + + dev = tvm.device(target, 0) + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + +def test_tensorrt_offload(): + @memoize("relax.tests.test_codegen_tensorrt.conv2d_residual") + def get_ref(): + data_np = np.random.randn(1, 64, 56, 56).astype("float32") + weight1_np = np.random.randn(64, 64, 3, 3).astype("float32") + weight2_np = np.random.randn(64, 64, 3, 3).astype("float32") + inputs = [data_np, weight1_np, weight2_np] + ref = build_and_run(Conv2dResidualBlock, inputs, "llvm", legalize=True) + return inputs, ref + + inputs, ref = get_ref() + + conv_pat = make_fused_bias_activation_pattern( + "relax.nn.conv2d", with_bias=False, activation=None + ) + relu_pat = is_op("relax.nn.relu")(wildcard()) + add_pat = is_op("relax.add")(wildcard(), wildcard()) + + patterns = [ + ("tensorrt.nn.conv2d", conv_pat), + ("tensorrt.nn.relu", relu_pat), + ("tensorrt.add", add_pat), + ] + + params_np = {"weight1": inputs[1], "weight2": inputs[2]} + + mod = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", params_np), + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(Conv2dResidualBlock) + + out = build_and_run(mod, inputs[:1], "cuda") + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_tensorrt_offload() diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py new file mode 100644 index 000000000000..9c960ed355d3 --- /dev/null +++ b/tests/python/relax/test_codegen_tir_cutlass.py @@ -0,0 +1,709 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import tempfile + +from tvm import relax, runtime +import tvm +import tvm.testing +from tvm import relax +import scipy +from scipy.special import erf +import numpy as np +from tvm.target import Target +from tvm.relax.vm_build import build as relax_build +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import IRBuilder + +from tvm.relax.backend_tir import get_tir_pattern +from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen, compile_options + +A_TYPE = "float16" +B_TYPE = "float16" +C_TYPE = "float16" + +target = Target("cuda") + + +def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input): + vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device) + return vm["main"](*input) + + +def build(mod): + mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FuseOps()(mod) + mod = relax.transform.FuseTIR()(mod) + mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), cutlass_fcodegen())(mod) + mod = relax.transform.DeadCodeElimination()(mod) + print(mod.script()) + f = tempfile.NamedTemporaryFile(suffix=".so", delete=True) + executable = relax_build(mod, target) + + executable.mod.export_library(f.name, **compile_options(target)) + rt_mod = runtime.load_module(f.name) + f.close() + return rt_mod + + +def build_and_run_reference(mod, inputs_np): + mod = relax.transform.LegalizeOps()(mod) + dev = tvm.device("llvm", 0) + ex = relax.build(mod, "llvm") + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + +def constructGEMM(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +def constructGEMM_bias2(M, N, K): + with IRBuilder() as ib: # pylint: 
disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias2(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias2(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias_relu(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.nn.relu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias_relu(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias_relu(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense(): + b, m, n, k = 
2, 128, 256, 64 + executable = build(constructBatchGEMM(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM2(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((batch, K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(b, k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = 
np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2_gelu(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.nn.gelu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2_gelu(): + b, m, n, k = 2, 128, 64, 256 + executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + C = A @ B + bias + O = 0.5 * C * (1 + erf(C / np.sqrt(2))) + np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2_mul(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + residual = R.arg("residual", relax.TensorStructInfo((batch, M, N), A_TYPE)) + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.multiply(D, residual)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2_mul(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias2_mul(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + residual = np.random.randn(b, m, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + residual_tvm = tvm.nd.array(residual, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm) + np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM2_bias(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((batch, K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = 
R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense2_bias(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM2_bias(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(b, k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + mod = ib.get() + return mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d(): + n, c, h, w = 1, 3, 224, 224 + kh, kw, o = 3, 3, 64 + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation) + executable = build(mod) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, kh, kw, c).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + result_ref = build_and_run_reference(mod, [A, B]) + np.testing.assert_allclose( + result.numpy(), + result_ref, + rtol=5e-2, + atol=5e-2, + ) + + +def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + mod = ib.get() + return 
mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d_bias(): + c, h, w = 3, 224, 224 + kh, kw, o = 3, 3, 64 + for n in [1, 2]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation) + executable = build(mod) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, kh, kw, c).astype("float16") + bias = np.random.randn(1, 1, 1, o).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + result_ref = build_and_run_reference(mod, [A, B, bias]) + np.testing.assert_allclose( + result.numpy(), + result_ref, + rtol=5e-2, + atol=5e-2, + ) + + +def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation): + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import ir as I + from tvm.script.ir_builder import relax as R + from tvm.script.ir_builder import tir as T + + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + x = R.arg( + "x", relax.TensorStructInfo((N, H, W, C), A_TYPE) + ) # pylint: disable=invalid-name + w = R.arg( + "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE) + ) # pylint: disable=invalid-name + res = R.arg( + "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit( + R.nn.conv2d( + x, + w, + strides=strides, + padding=padding, + dilation=dilation, + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype=C_TYPE, + ) + ) + D = R.emit(R.add(C, bias)) + E = R.emit(R.add(D, res)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + mod = ib.get() + return mod + + +@tvm.testing.requires_cutlass +def test_cutlass_conv2d_bias_add(): + n, c, h, w = 2, 3, 224, 224 + kh, kw, o = 3, 3, 64 + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (3, 3)]: + for dilation in [(1, 1), (4, 4)]: + oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1 + ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1 + mod = constructConv2D_bias_add( + n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation + ) + executable = build(mod) + dev = tvm.cuda() + np.random.seed(0) + A = np.random.randn(n, h, w, c).astype("float16") + B = np.random.randn(o, kh, kw, c).astype("float16") + bias = np.random.randn(1, 1, 1, o).astype("float16") + res = np.random.randn(n, oh, ow, o).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + res_tvm = tvm.nd.array(res, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm) + result_ref = build_and_run_reference(mod, [A, B, bias, res]) + np.testing.assert_allclose( + result.numpy(), + result_ref, + rtol=5e-2, + atol=5e-2, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py new file mode 100644 index 000000000000..b85543cafcb8 --- /dev/null +++ b/tests/python/relax/test_dataflow_pattern.py @@ -0,0 +1,1231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm.testing + +from tvm import relay, relax +from tvm.relax.dpl import * +from tvm.relax.analysis import get_var2val +from tvm import relax as rx, tir +from tvm.script import relax as R, tir as T + + +@tvm.script.ir_module +class Module: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.int32() + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + cls = Module + with R.dataflow(): + lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +main_fn = Module["main"] +bindings = main_fn.body.blocks[0].bindings + +## Node-wise Matching +def test_expr_pattern(): + ep = is_expr(rx.Var("x")) + assert isinstance(ep, ExprPattern) + assert isinstance(ep.expr, rx.Var) + + +def test_var_pattern(): + v = is_var("x") + assert isinstance(v, VarPattern) + assert v.name == "x" + assert v.match(rx.Var("x")) + assert is_var().match(rx.Var("x")) + assert is_var().match(rx.DataflowVar("x")) # DataflowVar is also a Var + assert not v.match(rx.GlobalVar("x")) + + +def test_dataflow_var_pattern(): + v = is_dfv("x") + assert isinstance(v, DataflowVarPattern) + assert v.name == "x" + assert v.match(rx.DataflowVar("x")) + assert not v.match(rx.GlobalVar("x")) + assert is_dfv().match(bindings[0].var) + + +def test_global_var_pattern(): + assert is_gv("x").match(rx.GlobalVar("x")) + assert is_gv().match(rx.GlobalVar("x")) + assert not is_gv("x").match(rx.GlobalVar("y")) + assert not is_gv("x").match(rx.Var("x")) + + +def test_constant_pattern(): + c = is_const() + assert isinstance(c, ConstantPattern) + assert c.match(rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])) + + +def test_wildcard_pattern(): + wc = wildcard() + assert isinstance(wc, WildcardPattern) + assert wc.match(rx.Var("x")) + + +def test_call_pattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("relax.add")(wc1, wc2) + assert isinstance(c, CallPattern) + assert isinstance(c.args[0], WildcardPattern) + assert isinstance(c.args[1], WildcardPattern) + assert c.match(rx.op.add(rx.Var("x"), rx.Var("y"))) + + 
+def test_function_pattern(): + wc1 = wildcard() + wc2 = wildcard() + f = FunctionPattern([wc1, wc2], is_op("relax.add")(wc1, wc2)) + assert isinstance(f, FunctionPattern) + assert isinstance(f.params[0], WildcardPattern) + assert isinstance(f.params[1], WildcardPattern) + assert isinstance(f.body, CallPattern) + assert isinstance(f.body.args[0], WildcardPattern) + assert isinstance(f.body.args[1], WildcardPattern) + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))) + assert not f.match( + rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32")) + ) + + +def test_tuple_pattern(): + wc1 = wildcard() + wc2 = is_dfv() + t = is_tuple([wc1, wc2]) + assert isinstance(t, TuplePattern) + assert isinstance(t.fields[0], WildcardPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")])) + assert not t.match(rx.Tuple([])) + assert t[0].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # Negative index is also allowed + assert t[-1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # None means any index. + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + with pytest.raises(IndexError): + t[2] # index cannot be greater than or equal to the tuple size. + + +def test_unordered_tuple_pattern(): + t = is_tuple([is_const(), is_dfv()], unordered=True) + assert isinstance(t, UnorderedTuplePattern) + assert isinstance(t.fields[0], ConstantPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")])) + assert t.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([])) + + +def test_tuple_get_item_pattern(): + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + + +def test_or_pattern(): + dfv_or_gv = is_dfv("x") | is_gv("x") + assert isinstance(dfv_or_gv, OrPattern) + assert dfv_or_gv.match(rx.DataflowVar("x")) + assert dfv_or_gv.match(rx.GlobalVar("x")) + assert not dfv_or_gv.match(rx.Var("x")) + assert not dfv_or_gv.match(rx.DataflowVar("y")) + assert not dfv_or_gv.match(rx.GlobalVar("y")) + + +def test_and_pattern(): + # float[2, 3, 3] + f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32") + assert isinstance(f32_233, AndPattern) + assert f32_233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32"))) + assert not f32_233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32"))) + assert not f32_233.match(rx.Var("x", R.Tensor("float32", ndim=3))) + + +def test_not_pattern(): + no_shape233 = ~wildcard().has_shape((2, 3, 3)) + assert isinstance(no_shape233, NotPattern) + assert no_shape233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32"))) + assert not no_shape233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32"))) + + +def 
test_type_pattern(): + assert wildcard().has_type(rx.DynTensorType(2, "float32")).match(bindings[0].var) + + +def test_dtype_pattern(): + dtype = "float16" + pattern = has_dtype(dtype) + assert isinstance(pattern, DataTypePattern) + assert pattern.dtype == dtype + assert has_dtype("float32").match(bindings[0].var) + + +def test_shape_pattern(): + shape = [32, 32] + pattern = wildcard().has_shape(shape) + assert isinstance(pattern, ShapePattern) + tvm.ir.structural_equal(pattern.shape, shape) + assert pattern.match(bindings[0].var) + assert wildcard().has_shape([32, 32]).match(bindings[0].var) + n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32")) + assert wildcard().has_shape([n, m, n + m]).match(symsh_var) + assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative. + assert not wildcard().has_shape([1, 2, 3]).match(symsh_var) + assert not wildcard().has_shape([m, n, n + m]).match(symsh_var) + + +def test_prim_arr_pattern(): + """ + The difference between is_shape and has_shape is that: + 1) is_shape directly matches a shape (e.g., as an argument); + 2) has_shape matches a tensor and puts assumptions on the tensor's shape. + """ + pattern = is_shape([32, 32]) + assert pattern[0] == 32 + assert pattern[1] == 32 + assert isinstance(pattern, PrimArrPattern) + assert pattern.match(rx.get_shape_of(bindings[0].var)) + n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + symbolic_shape = rx.ShapeExpr([n, m, n + m]) + assert is_shape([n, m, n + m]).match(symbolic_shape) + assert not is_shape([n, m, n * m]).match(symbolic_shape) + + +def test_extern_fn_pattern(): + pattern = ExternFuncPattern("test.blockbuilder.nop") + assert pattern.match(rx.ExternFunc("test.blockbuilder.nop")) + + +def test_op_attr(): + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + conv2d = relay.nn.conv2d(x, y, kernel_size=(3, 3)) + xp = is_var("x") + yp = is_var("y") + # TODO(@yuchen): reenable the assert after figuring out why it fails + # assert is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [3, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size_": [3, 3]}).match(conv2d) + + +def test_match_call_attr(): + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32")) + annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + xp = is_var("x") + yp = is_var("y") + root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp)) + assert root_pattern.has_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}).match( + annotated_fn + ) + + assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn) + assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn) + assert root_pattern.has_attr({}).match(annotated_fn) + + +def test_is_call_tir(): + lv1_val = bindings[1].value + var2val = get_var2val(Module["main"]) + assert is_call_tir("tir_relu").match(lv1_val) + assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val) + assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val) + + +@R.function +def simple_call_packed( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") +) -> R.Tensor: + gv0 = 
R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return gv0 + + +def test_varg_default_wildcard(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + yes_pattern_explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard()) + yes_pattern_implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True) + no_pattern = ExternFuncPattern("test.vm.mul")(wildcard()) + + assert yes_pattern_explicit.match(expr) + assert yes_pattern_implicit.match(expr) + assert not no_pattern.match(expr) + + +def test_simple_call_packed(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + assert is_call_packed("test.vm.mul").match(expr) + assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(expr) + + +## Graph-wise Matching +def test_simple_used_by(): + with PatternContext() as ctx: + n0 = is_var("x") # x is a free var (fn arg) + n1 = wildcard() + n0 ^ n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == main_fn.params[0] + assert matched[n1] == dfb.bindings[0].var + + +def test_simple_call_tir_edge(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0.used_by(n1) + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_simple_oub(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0 >> n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_counter_syntax_match(): + with PatternContext() as ctx: + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") + n0 >> n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + with PatternContext() as ctx: + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") + n0 ^ n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class Diamond: + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # matmul + # / \ + # relu sigmoid + # \ / + # add + lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) + R.output(lv3) + return lv3 + + +def test_diamond(): + with PatternContext() as ctx: + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") + + n0 ^ n1 + n0 ^ n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + + assert ctx.match_dfb(dfb) + # simplify it with fork_to + with PatternContext() as ctx: + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") + + is_call_dps_packed("extern_matmul").fork_to(n1, n2) + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_diamond_counter_oub(): + with PatternContext() as ctx: + n0 = is_call_dps_packed("extern_matmul") + n1 = 
is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") + + n0 >> n1 + n0 >> n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class SmallDiamond: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu + # / \ + # \ / + # add + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +@tvm.script.ir_module +class SmallParallel: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu relu + # \ / + # add + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + R.output(lv2) + return lv2 + + +def test_distinguish_diamond_and_parallel(): + # relay pattern lang cannot distinguish the two cases above. + diamond = SmallDiamond["main"].body.blocks[0] + parallel = SmallParallel["main"].body.blocks[0] + + with PatternContext() as ctx: + # describe a diamond pattern + fork = is_call_dps_packed("my_relu") + join = is_call_dps_packed("my_add") + fork.only_used_by(join, index=0) + fork.only_used_by(join, index=1) + + assert ctx.match_dfb(diamond) + assert not ctx.match_dfb(parallel) + + with PatternContext() as ctx: + # describe a parallel pattern + join = is_call_dps_packed("my_add") + # Due to one-one matching: + # is_call_dps_packed("my_relu") creates the 1st relu + is_call_dps_packed("my_relu") >> join + # is_call_dps_packed("my_relu") + # creates the another different relu (obj address is different) + is_call_dps_packed("my_relu") >> join + + assert ctx.match_dfb(parallel) + assert not ctx.match_dfb(diamond) + + +@tvm.script.ir_module +class CBRx2: + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w0: R.Tensor((1, 1), "float32"), + bias0: R.Tensor((32, 32), "float32"), + w1: R.Tensor((1, 1), "float32"), + bias1: R.Tensor((32, 32), "float32"), + ) -> R.Tensor: + # R.TensorRT's CBR Optimization Pattern + # input + # / \ + # cbr0 cbr1 + # \ / + # concat + with R.dataflow(): + lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) + R.output(lv6) + return lv6 + + +def test_single_cbr(): + with PatternContext() as ctx: + ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") + ) + dfb = CBRx2["main"].body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + + with PatternContext() as ctx: + chain = ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") + ) + dfb = CBRx2["main"].body.blocks[0] + # we want to specifically match the first CBR (lv0) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var) + 
assert matched + assert matched[chain[0]] == dfb.bindings[0].var + # we want to specifically match the second CBR (lv3) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var) + assert matched + assert matched[chain[0]] == dfb.bindings[3].var + + +def test_counter_single_crb(): + with PatternContext() as ctx: + ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("my_relu") + >> is_call_dps_packed("bias_add") + ) + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + # Quickly fails unpromising matches by assuming `start_hint` must be matched by a pattern. + # This is usually faster than the full match: + # Full match: let one pattern to match -> all Var: complexity ~ #Var + # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns + # Usually #patterns is much smaller than #Var, so this is faster. + assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True) + + +def test_nested_context(): + dfb = CBRx2["main"].body.blocks[0] + with PatternContext() as ctx0: + ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") + ) + with PatternContext() as ctx1: + is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") # pattern to miss + with PatternContext() as ctx2: + is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu") + assert ctx2.match_dfb(dfb) + assert PatternContext.current() == ctx2 + assert not ctx1.match_dfb(dfb) + assert PatternContext.current() == ctx1 + assert ctx0.match_dfb(dfb) + assert PatternContext.current() == ctx0 + + +def test_two_cbr(): + with PatternContext() as ctx: + cbr0 = ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") + ) + cbr1 = cbr0.dup() + + assert cbr0.patterns[0] != cbr1.patterns[0] + assert cbr0.patterns[1] != cbr1.patterns[1] + assert cbr0.patterns[2] != cbr1.patterns[2] + + is_var("x").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + # Deny the pattern + cbr0 = ( + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") + ) + cbr1 = cbr0.dup() + + # input has no fork at y. + is_var("y").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +def test_two_matmul(): + # Same as Figure 2(a) in TASO paper. + @tvm.script.ir_module + class MatMul2: + @R.function + def main( + a: R.Tensor((32, 16), "float32"), + b: R.Tensor((16, 48), "float32"), + c: R.Tensor((48, 32), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) + lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + with PatternContext() as ctx: + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape( + [32, 32] + ) + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") + dfb = MatMul2["main"].body.blocks[0] + # Three MatMul cannot match + assert not ctx.match_dfb(dfb) + + +def test_concat_mm_split(): + # Same as Figure 2(b) in TASO paper. 
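+    # Graph: concat(b, c) -> matmul(a, .) -> split, then the two split halves are added.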
+ @tvm.script.ir_module + class CMS: + @R.function + def main( + a: R.Tensor((32, 32), "float32"), + b: R.Tensor((16, 32), "float32"), + c: R.Tensor((16, 32), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed( + "my_split", + (lv1,), + [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")], + ) + lv3 = R.TupleGetItem(lv2, 0) + lv4 = R.TupleGetItem(lv2, 1) + lv5 = R.add(lv3, lv4) + R.output(lv5) + return lv5 + + with PatternContext() as ctx: + ( + is_call_dps_packed("my_concat") + >> is_call_dps_packed("my_matmul") + >> is_call_dps_packed("my_split") + ) + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + split = is_call_dps_packed("my_split") + lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32]) + lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32]) + split.fork_to(lv3, lv4) + add = is_op("relax.add")(lv3, lv4) + # TODO(@ganler): simplify this through implicit graph pattern. + lv3 >> add + lv4 >> add + + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_self_attention(): + # The example comes from. + # https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/ + @tvm.script.ir_module + class SelfAttention: + @R.function + def main( + x: R.Tensor(("b", "s", "n", "h"), "float32"), + wq: R.Tensor(("h", "h"), "float32"), + wk: R.Tensor(("h", "h"), "float32"), + wv: R.Tensor(("h", "h"), "float32"), + ) -> R.Tensor: + b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) + tpq = R.call_dps_packed( + "my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32") + ) + + fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) + tpk = R.call_dps_packed( + "my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32") + ) + + mul = R.multiply(tpq, tpk) + scale = R.multiply(mul, R.const(1.1, "float32")) + softmax = R.call_dps_packed( + "softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32") + ) + + fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) + tpv = R.call_dps_packed( + "my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32") + ) + + out = R.multiply(softmax, tpv) + R.output(out) + + return out + + with PatternContext() as ctx: + fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose") + fc_trans_k = fc_trans_q.dup() + fc_trans_v = fc_trans_q.dup() + + is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v) + dfb = SelfAttention["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_nested_diamond(): + @tvm.script.ir_module + class DiamondInDiamond: + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # matmul0 matmul1 + # / \ / \ + # sigmoid2 add4 sigmoid3 + # \ / \ / + # add5 add6 + # \ / + # add7 + lv0 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv1 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv2 = R.call_dps_packed( + "extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3 = R.call_dps_packed( + "extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32") + ) + lv4 = R.call_dps_packed( + "extern_add", (lv0, lv1), 
R.Tensor((32, 32), dtype="float32") + ) + lv5 = R.call_dps_packed( + "extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv6 = R.call_dps_packed( + "extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv7 = R.call_dps_packed( + "extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv7) + return lv7 + + # match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") + sigmoid2 >> add5 + add4 ^ add5 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # counter case: mis-match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") + sigmoid2 >> add5 + add4 >> add5 # not only-used-by relation + assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match matmul1 diamond + with PatternContext() as ctx: + sigmoid3 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4) + add6 = is_call_dps_packed("extern_add") + sigmoid3 >> add6 + add4 ^ add6 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match add-4-5-6-7 + with PatternContext() as ctx: + add5, add6, add7 = ( + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), + ) + is_call_dps_packed("extern_add").fork_to(add5, add6) # add4 + add5 >> add7 + add6 >> add7 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + +def test_incremental_solving(): + @R.function + def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu -> sigmoid -> neg + lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32")) + R.output(lv2) + return lv2 + + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") + + with PatternContext() as ctx0: + relu >> sigmoid + with PatternContext(incremental=True) as ctx1: + # because we are doing incremental solving + # relu >> sigmoid is still a constraint in this context. 
+ # that said the total constraint is: + # relu >> sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + # match relue -> sigmoid + assert ctx0.match_dfb(simple_chain.body.blocks[0]) + + +def test_incremental_solving_counter(): + @R.function + def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # sigmoid -> neg + lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") + + with PatternContext() as ctx0: + relu >> sigmoid # cannot match + + with PatternContext(incremental=False) as ctx1: + # total constraint: sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + with PatternContext(incremental=True) as ctx1: + # total constraint: relu >> sigmoid >> neg + sigmoid >> neg + assert not ctx1.match_dfb(simple_chain.body.blocks[0]) + + +def test_rewrite_simple(): + @R.function + def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"): + with R.dataflow(): + x2 = R.add(x, x) + x4 = R.add(x2, x2) + R.output(x4) + return x4 + + @R.function + def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32")) + x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32")) + R.output(x4) + return x4 + + @R.function + def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"): + with R.dataflow(): + x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32")) + R.output(x4) + return x4 + + x = wildcard() + pattern = is_op("relax.add")(x, x) + + def rewriter(_, matchings): + return R.multiply(matchings[x], R.const(2, "float32")) + + rewritten = rewrite_call(pattern, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected1) + + add1 = is_op("relax.add")(x, x) + pattern = is_op("relax.add")(add1, add1) + + def rewriter(_, matchings): + return R.multiply(matchings[x], R.const(4, "float32")) + + rewritten = rewrite_call(pattern, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected2) + + # No rewriting, return the original call node as is + def rewriter(orig, _): + return orig + + rewritten = rewrite_call(pattern, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, main) + + +def test_rewrite_attention(): + @R.function + def main( + Q: R.Tensor((2, 4096, 8, 40), "float32"), + K: R.Tensor((2, 4096, 8, 40), "float32"), + V: R.Tensor((2, 4096, 8, 40), "float32"), + ) -> R.Tensor((2, 4096, 8, 40), "float32"): + with R.dataflow(): + lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3]) + lv59 = R.reshape(lv58, R.shape([16, 4096, 40])) + + lv61 = R.permute_dims(K, axes=[0, 2, 1, 3]) + lv62 = R.reshape(lv61, R.shape([16, 4096, 40])) + + lv64 = R.permute_dims(V, axes=[0, 2, 1, 3]) + lv65 = R.reshape(lv64, R.shape([16, 4096, 40])) + + lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1]) + lv3_1 = R.matmul(lv59, lv62_transposed) + lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32")) + lv69 = R.nn.softmax(lv68, axis=-1) + lv_3 = R.matmul(lv69, lv65) + + lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40])) + lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3]) + R.output(lv72) + + return lv72 + + 
@R.function + def expected( + Q: R.Tensor((2, 4096, 8, 40), dtype="float32"), + K: R.Tensor((2, 4096, 8, 40), dtype="float32"), + V: R.Tensor((2, 4096, 8, 40), dtype="float32"), + ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"): + with R.dataflow(): + lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K) + R.output(lv72) + return lv72 + + def BSNH_to_BSH(tensor): + return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard()) + + def BSH_to_BSNH(tensor): + return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard())) + + Q = wildcard() + K = wildcard() + V = wildcard() + + Q_3D = BSNH_to_BSH(Q) + V_3D = BSNH_to_BSH(V) + K_3D = BSNH_to_BSH(K) + + matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D)) + multiply = is_op("relax.multiply")(matmul1, is_const()) + softmax = is_op("relax.nn.softmax")(multiply) + matmul2 = is_op("relax.matmul")(softmax, K_3D) + + pattern = BSH_to_BSNH(matmul2) + + def rewriter(_, matchings): + return R.nn.attention(matchings[Q], matchings[K], matchings[V]) + + rewritten = rewrite_call(pattern, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected) + + +def test_attention_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x, w0) + lv1 = R.matmul(x, w1) + lv2 = R.matmul(x, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + dfb = QKV_proj["main"].body.blocks[0] + out = ctx.match_dfb(dfb) + + assert out[Q_weight_pat].name_hint == "w0" + assert out[K_weight_pat].name_hint == "w1" + assert out[V_weight_pat].name_hint == "w2" + + +def test_attention_fake_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv1 = R.matmul(x2, w1) + lv2 = R.matmul(x2, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + dfb = QKV_proj["main"].body.blocks[0] + assert ctx.match_dfb(dfb) is None + + +def get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 +): + def qkv_proj_rewriter(matchings): + inp = matchings[inp_pat] + Q_weight = matchings[Q_weight_pat] + K_weight = matchings[K_weight_pat] + V_weight = matchings[V_weight_pat] + width = Q_weight.struct_info.shape[1] + + concat = R.concat([Q_weight, K_weight, V_weight], axis=1) + matmul = R.matmul(inp, concat) + Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width]) + K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2]) + V = 
R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3]) + + return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} + + return qkv_proj_rewriter + + +def test_combine_matmul_twice(): + @R.function + def qkv_x2( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + w3: R.Tensor((640, 640), "float32"), + w4: R.Tensor((640, 640), "float32"), + w5: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv1 = R.matmul(x1, w1) + lv2 = R.matmul(x1, w2) + lv3 = R.matmul(x2, w3) + lv4 = R.matmul(x2, w4) + lv5 = R.matmul(x2, w5) + out = (lv0, lv1, lv2, lv3, lv4, lv5) + R.output(out) + return out + + @R.function + def expected( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + w3: R.Tensor((640, 640), "float32"), + w4: R.Tensor((640, 640), "float32"), + w5: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv = R.concat((w0, w1, w2), axis=1) + lv1 = R.matmul(x1, lv) + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280]) + lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920]) + lv2_1 = R.concat((w3, w4, w5), axis=1) + lv3 = R.matmul(x2, lv2_1, out_dtype="void") + lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640]) + lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280]) + lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920]) + out = lv0, lv1_1, lv2, lv3_1, lv4, lv5 + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + rewriter = get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 + ) + rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) + tvm.ir.assert_structural_equal(rewritten, expected) + + +def test_combine_matmul_emit_order(): + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + w0_t = R.permute_dims(w0, axes=None) + lv0 = R.matmul(x1, w0_t) + w1_t = R.permute_dims(w1, axes=None) + w1_t_t = R.permute_dims(w1_t, axes=None) + lv1 = R.matmul(x1, w1_t_t) + w2_t = R.permute_dims(w2, axes=None) + lv2 = R.matmul(x1, w2_t) + out = (lv0, lv1, lv2) + R.output(out) + return out + + @R.function + def expected( + x1: R.Tensor((2, 1024, 640), dtype="float32"), + w0: R.Tensor((640, 640), dtype="float32"), + w1: R.Tensor((640, 640), dtype="float32"), + w2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor: + with R.dataflow(): + w0_t = R.permute_dims(w0, axes=None) + w1_t = R.permute_dims(w1, axes=None) + w1_t_t = R.permute_dims(w1_t, axes=None) + w2_t = R.permute_dims(w2, axes=None) + lv = R.concat((w0_t, w1_t_t, w2_t), axis=1) + lv1 = R.matmul(x1, lv, out_dtype="void") + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280]) + 
lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920]) + out = lv0, lv1_1, lv2 + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + rewriter = get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 + ) + rewritten = rewrite_bindings(ctx, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected) + + # make sure it builds + mod = tvm.IRModule() + mod["main"] = rewritten + mod = relax.transform.LegalizeOps()(mod) + + relax.build(mod, target="llvm") + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py new file mode 100644 index 000000000000..902c4785610f --- /dev/null +++ b/tests/python/relax/test_expr.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
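+"""Unit tests for Relax expression and binding nodes: Var, DataflowVar, Tuple, MatchCast,
+VarBinding, binding blocks, SeqExpr, Function, ShapeExpr, PrimValue, StringImm, DataTypeImm."""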
+import numpy as np +import tvm +from tvm import relax as rx +from tvm import tir +from tvm.script import relax as R +import pytest + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_var() -> None: + v0 = rx.Var("v0") + assert v0.name_hint == "v0" + assert v0._checked_type_ is None + assert v0.struct_info_ is None + shape = [54, 96] + v1 = rx.Var("v1", R.Tensor(shape, "float32")) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.struct_info.shape, shape): + assert s0 == s1 + assert v1.checked_type == rx.DynTensorType(2, "float32") + tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32")) + + +def test_dataflow_var() -> None: + v0 = rx.DataflowVar("v0") + assert v0.name_hint == "v0" + assert v0._checked_type_ is None + assert v0.struct_info_ is None + + shape = [54, 96] + v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16")) + assert v1.name_hint == "v1" + + assert v1._checked_type_ == rx.DynTensorType(2, "float16") + assert isinstance(v1, rx.DataflowVar) + tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) + + +def test_tuple() -> None: + v0 = rx.Var("v0") + v1 = rx.Var("v1") + t = rx.Tuple((v0, v1)) + + assert t.fields[0] == v0 + assert t.fields[1] == v1 + assert t[0] == v0 + assert t[1] == v1 + assert t[-1] == v1 + assert t[-2] == v0 + + with pytest.raises(IndexError, match="Tuple index out of range"): + t[2] + + with pytest.raises(IndexError, match="Tuple index out of range"): + t[-3] + + +def test_match_cast() -> None: + # match_cast([16, 8], [m, n]) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + var = rx.Var("v0", R.Shape()) + b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) + assert b0.value == shape + assert b0.pattern[0] == m + assert b0.pattern[1] == n + assert b0.var is not None + assert b0.var.checked_type == rx.ShapeType() + + # var1: R.Tensor((m, n), "float32") = + # match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32")) + value = rx.Var("value", R.Tensor("float32", ndim=-1)) + + var = rx.Var("v1", R.Tensor([m, n], "float32")) + b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32")) + assert b1.value == value + assert b1.pattern[0] == m + assert b1.pattern[1] == n + assert b1.var is not None + assert b1.var.checked_type == rx.DynTensorType(2, "float32") + + +def test_match_cast() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + ivalue = rx.Var("input_value") + sinfo = rx.TensorStructInfo([n, m], "float32") + b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo) + assert b0.value.same_as(ivalue) + assert b0.struct_info == sinfo + _check_json_roundtrip(b0) + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + assert b0.var.name_hint == "v0" + assert b0.value == val + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = 
rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + assert isinstance(block0, rx.DataflowBlock) + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + assert seqe.blocks[0] == blocks[0] + assert seqe.body == x + + +def test_func(): + x = rx.Var("foo", R.Tensor(dtype="float32", ndim=2)) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + + seqe = rx.SeqExpr(blocks, x) + ret_struct_info = R.Tensor(dtype="float32", ndim=-1) + func = rx.Function([x], seqe, ret_struct_info) + func = func.with_attr("global_symbol", "func") + assert func.params[0] == x + assert func.body == seqe + assert func.ret_struct_info == ret_struct_info + assert func.attrs["global_symbol"] == "func" + + +def test_shape_of(): + shape = [96, 54] + v1 = rx.Var("v1", R.Tensor(shape)) + s1 = rx.get_shape_of(v1) + for x, y in zip(shape, s1): + assert x == y + + +def test_shape_expr(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + s = rx.ShapeExpr([m, n]) + assert s.values[0] == m + assert s.values[1] == n + assert s[0] == m + assert s[1] == n + assert s[-1] == n + assert s[-2] == m + assert isinstance(s.struct_info, rx.ShapeStructInfo) + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[2] + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[-3] + + shape_expr = rx.ShapeExpr([10, 20]) + assert shape_expr.values[0] == 10 + assert shape_expr.values[1] == 20 + assert shape_expr.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) + + x = rx.Var("v0", R.Tensor((10, 20), "float32")) + assert x.struct_info.shape[0] == 10 + assert x.struct_info.shape[1] == 20 + assert x.struct_info.shape.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) + + m = tir.Var("m", "int32") + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + rx.ShapeExpr([m, 3]) + + +def test_prim_value(): + pv = rx.PrimValue(tir.IntImm("int64", 1)) + assert pv.value.value == 1 + _check_equal(pv, rx.PrimValue(tir.IntImm("int64", 1))) + _check_json_roundtrip(pv) + + +def test_string_imm(): + s0 = rx.StringImm("hello") + s1 = rx.StringImm("hello") + assert s0.value == "hello" + _check_equal(s0, s1) + _check_json_roundtrip(s0) + + +def test_datatype_imm(): + d0 = rx.DataTypeImm("int32") + d1 = rx.DataTypeImm("int32") + assert d0.value == "int32" + _check_equal(d0, d1) + _check_json_roundtrip(d0) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py new file mode 100644 index 000000000000..bd058e897979 --- /dev/null +++ b/tests/python/relax/test_expr_args_converter.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, List, Optional, Union + +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax import Expr +from tvm.relax.utils import args_converter + + +def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None: + # Test converting to `Expr` + assert f_checker(arg) + # Test converting `*args` + assert isinstance(args, tuple) + assert all([f_checker(arg) for arg in args]) + # Test converting `**kwargs` + assert isinstance(kwargs, dict) + assert all([f_checker(arg) for arg in kwargs.values()]) + + +def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None: + f_checker = lambda x: isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_optional_expr( + arg: Optional[Expr], *args: Optional[Expr], **kwargs: Optional[Expr] +) -> None: + f_checker = lambda x: x is None or isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_list_expr(arg: List[Expr], *args: List[Expr], **kwargs: List[Expr]) -> None: + f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_optional_list_expr( + arg: Optional[List[Expr]], *args: Optional[List[Expr]], **kwargs: Optional[List[Expr]] +) -> None: + f_checker = lambda x: x is None or ( + isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) + ) + _test_base(f_checker, arg, *args, **kwargs) + + +prim_value = 1 +str_value = "value_to_be_convert" +shape_value = (1, 1) +tuple_value = (relax.const(1), (1, 1)) +placeholder = relax.const(0) + +test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder] + + +def test_args_to_expr(): + for _f in [_test_expr, _test_optional_expr]: + f = args_converter.to_expr("arg", "args", "kwargs")(_f) + for x in test_cases: + f( + x, + x, # the first argument in *args + x, # the second argument in *args + test_kwargs=x, + ) + + if _f == _test_optional_expr: + f(None, None, x, test_kwargs=None) + + +def test_args_to_list_expr(): + for _f in [_test_list_expr, _test_optional_list_expr]: + f = args_converter.to_list_expr("arg", "args", "kwargs")(_f) + for x in test_cases: + f( + [x], + [x], # the first argument in *args + [x, x], # the second argument in *args + test_kwargs=[x, (x,)], + ) + + if _f == _test_optional_list_expr: + f(None, None, [x], test_kwargs=None) + + +def test_error(): + f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr) + with pytest.raises(TypeError): + f(prim_value) # fail to convert prim_value to `List[Expr]` + + +def test_auto_convert(): + for _f in [_test_expr, _test_optional_expr]: + f = args_converter.auto(_f) + for x in test_cases: + f(x, (x,), test_kwargs=x) + + if _f == _test_optional_expr: + f(None, x, test_kwargs=None) + + for _f in [_test_list_expr, 
_test_optional_list_expr]: + f = args_converter.auto(_f) + for x in test_cases: + f([x], [x, x], test_kwargs=[x, (x,)]) + + if _f == _test_optional_list_expr: + f(None, None, [x], test_kwargs=None) + + +def test_auto_convert_skip(): + def _test_expr_skip(arg: int, *args: Union[str, Expr], **kwargs: List[Optional[Expr]]) -> None: + f_checker = lambda x: not isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + f = args_converter.auto(_test_expr_skip) + f(1, "str", test_kwargs=[None]) + + +def test_empty_tuple(): + def _test(arg: Expr): + assert isinstance(arg, relax.Tuple) + + f = args_converter.auto(_test) + f(()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py new file mode 100644 index 000000000000..8165107394c9 --- /dev/null +++ b/tests/python/relax/test_expr_functor.py @@ -0,0 +1,746 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relax, tir +from tvm.ir import Op +from tvm.ir.base import assert_structural_equal +from tvm.relax import PyExprMutator, PyExprVisitor +from tvm.relax.expr import ( + BindingBlock, + Call, + Constant, + DataflowBlock, + DataflowVar, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + PrimValue, + StringImm, + DataTypeImm, + Var, + VarBinding, +) +from tvm.script import relax as R +import pytest + +m, n = tir.Var("m", "int64"), tir.Var("n", "int64") +x = relax.Var("x", R.Tensor([n], "float32")) +y = relax.Var("y", R.Tensor([m, n], "float32")) +bb = relax.BlockBuilder() + + +@relax.expr_functor.visitor +class BasicVisitor(PyExprVisitor): + """Default ExprVisitor""" + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@relax.expr_functor.visitor +class ASTPrinter(PyExprVisitor): + """Print relax AST in structured format. 
The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> None: + self.log.add("Constant") + + def visit_global_var_(self, op: GlobalVar) -> None: + self.log.add("GlobalVar") + + def visit_tuple_(self, op: Tuple) -> None: + self.log.add("Tuple") + self.log.push_scope() + for field in op.fields: + self.visit_expr(field) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + self.log.add("DataflowVar") + + def visit_function_(self, op: Function) -> None: + self.log.add("Function") + self.log.push_scope() + for param in op.params: + self.visit_var_def(param) + + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_call_(self, op: Call) -> None: + self.log.add("Call") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_if_(self, op: If) -> None: + self.log.add("If") + self.log.push_scope() + self.visit_expr(op.cond) + self.visit_expr(op.true_branch) + self.visit_expr(op.false_branch) + self.log.pop_scope() + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + self.log.add("TupleGetItem") + self.log.push_scope() + self.visit_expr(op.tuple_value) + self.log.pop_scope() + + def visit_prim_value_(self, op: PrimValue) -> None: + self.log.add("PrimValue") + + def visit_string_imm_(self, op: StringImm) -> None: + self.log.add("StringImm") + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + self.log.add("DataTypeImm") + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + self.log.add("ShapeExpr") + + def visit_extern_func_(self, op: ExternFunc) -> None: + self.log.add("ExternFunc") + + def visit_seq_expr_(self, op: SeqExpr) -> None: + self.log.add("SeqExpr") + self.log.push_scope() + for block in op.blocks: + self.visit_binding_block(block) + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_var_binding_(self, binding: VarBinding) -> None: + self.log.add("VarBinding") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_match_cast_(self, binding: MatchCast) -> None: + self.log.add("MatchCast") + self.log.push_scope() + self.visit_var_def(binding.var) + self.visit_expr(binding.value) + self.log.pop_scope() + + def visit_binding_block_(self, block: BindingBlock) -> None: + self.log.add("BindingBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + self.log.add("DataflowBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_var_def_(self, var: Var) -> None: + self.log.add("VarDef") + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + self.log.add("DataflowVarDef") + + +@relax.expr_functor.mutator +class BasicMutator(PyExprMutator): + """Default ExprMutator""" + + +@relax.expr_functor.mutator +class ASTPostPrinterMutator(PyExprMutator): + """Print relax AST in the post order format.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Constant") + return op + + def visit_global_var_(self, op: GlobalVar) -> 
Expr: + op = self.visit_expr_post_order(op) + self.log.add("GlobalVar") + return op + + def visit_tuple_(self, op: Tuple) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Tuple") + return op + + def visit_var_(self, op: Var) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Var") + return op + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataflowVar") + return op + + def visit_function_(self, op: Function) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Function") + return op + + def visit_call_(self, op: Call) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Call") + return op + + def visit_if_(self, op: If) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("If") + return op + + def visit_op_(self, op: Op) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Op") + return op + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("TupleGetItem") + return op + + def visit_prim_value_(self, op: PrimValue) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("PrimValue") + return op + + def visit_string_imm_(self, op: StringImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("StringImm") + return op + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataTypeImm") + return op + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ShapeExpr") + return op + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ExternFunc") + return op + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("SeqExpr") + return op + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_var = self.visit_var_def(binding.var) + + self.log.add("VarBinding") + if binding.var.same_as(new_var) and binding.value.same_as(new_value): + self.builder_.emit_normalized(binding) + return + + temp = self.with_struct_info(new_var, new_value.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.builder_.emit_normalized(VarBinding(new_var, new_value)) + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchCastNode* binding) on the C++ side.""" + new_var = self.visit_var_def(binding.var) + new_value = self.visit_expr(binding.value) + + temp = self.with_struct_info(new_var, binding.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.log.add("MatchCast") + self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info)) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" + self.builder_._begin_binding_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("BindingBlock") + return self.builder_._end_block() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side.""" + 
self.builder_._begin_dataflow_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("DataflowBlock") + return self.builder_._end_block() + + def visit_var_def_(self, var: Var) -> None: + """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side.""" + self.log.add("VarDef") + return var + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" + self.log.add("DataflowVarDef") + return var + + +def basic_check(expr, visitor_str, mutator_str): + def visit(f, expr): + if isinstance(expr, relax.Expr): + return f.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + return f.visit_binding_block(expr) + + # check no overloading case + basic_visitor = BasicVisitor() + visit(basic_visitor, expr) + + # check the output log + log_visitor = ASTPrinter() + visit(log_visitor, expr) + assert str(log_visitor.log) == visitor_str + + # check no overloading case + basic_mutator = BasicMutator() + # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(basic_mutator, expr), expr) + + # check the output log and return value + post_log_mutator = ASTPostPrinterMutator() + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(post_log_mutator, expr), expr) + assert str(post_log_mutator.log) == mutator_str + + +def test_constant(): + basic_check(relax.const(1.0), "Constant", "Constant") + + +def test_var(): + basic_check(x, "Var", "Var") + + +def test_dataflow_var(): + lv = relax.DataflowVar("lv", R.Tensor([n], "float32")) + basic_check(lv, "DataflowVar", "DataflowVar") + + +def test_tuple(): + t = relax.Tuple([x, y]) + basic_check(t, "\n".join(["Tuple", "\tVar", "\tVar"]), "\n".join(["Var", "Var", "Tuple"])) + + +def test_global_var(): + gv = relax.GlobalVar("gv") + basic_check(gv, "GlobalVar", "GlobalVar") + + +def test_seq_expr(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + basic_check( + seq_expr, + "\n".join( + [ + "SeqExpr", + "\tBindingBlock", + "\t\tVarBinding", + "\t\t\tConstant", + "\t\t\tVarDef", + "\tVar", + ] + ), + "\n".join(["Constant", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"]), + ) + + +def test_shape_expr(): + x = relax.ShapeExpr([m, n]) + basic_check(x, "ShapeExpr", "ShapeExpr") + + +def test_call(): + call_node = relax.op.add(x, y) + basic_check( + call_node, + "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), + "\n".join(["Op", "Var", "Var", "ShapeExpr", "Call"]), + ) + + +def test_if(): + if_node = relax.If(x, x, x) + basic_check( + if_node, + "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]), + ) + + +def test_tuple_getitem(): + tuple_getitem_node = relax.TupleGetItem(relax.Tuple([x, y]), 0) + basic_check( + tuple_getitem_node, + "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"]), + "\n".join(["Var", "Var", "Tuple", "TupleGetItem"]), + ) + + +def test_binding_block(): + bb._begin_binding_block() + gv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "BindingBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + 
"\t\t\tVar", + "\t\t\tVar", + "\t\tVarDef", + "\tMatchCast", + "\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "VarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "VarDef", + "MatchCast", + "BindingBlock", + ] + ), + ) + + +def test_dataflow_block(): + bb._begin_dataflow_block() + lv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "DataflowBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tDataflowVarDef", + "\tMatchCast", + "\t\tDataflowVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "DataflowVarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "DataflowVarDef", + "MatchCast", + "DataflowBlock", + ] + ), + ) + + +def test_function(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + func = relax.Function([x], seq_expr, R.Tensor([n], "float32")) + basic_check( + func, + "\n".join( + [ + "Function", + "\tVarDef", + "\tSeqExpr", + "\t\tBindingBlock", + "\t\t\tVarBinding", + "\t\t\t\tConstant", + "\t\t\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "VarDef", + "Constant", + "VarDef", + "VarBinding", + "BindingBlock", + "Var", + "SeqExpr", + "Function", + ] + ), + ) + + +def test_extern_func(): + func = relax.ExternFunc("f") + basic_check(func, "ExternFunc", "ExternFunc") + + +def test_inherit(): + # The internal class is not instantiated. + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_inherit_with_cls(): + # The decorator converts `InternalVisitor` to a wrapper class. + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined. 
+ @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_wrong_inherit(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def visit_call_(self, op: Call) -> None: + pass + + with pytest.raises( + TypeError, + match="Inheritance from a decorated object `LeafVisitor` is not allowed. Please inherit from `LeafVisitor._cls`.", + ): + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + pass + + +def test_call_visitor_super(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + super().visit_call_(op) # call PyExprVisitor.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + super().visit_call_(op) # call InternalVisit.visit_call_ + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +def test_call_mutator_super(): + @relax.expr_functor.mutator + class InternalMutator(PyExprMutator): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + return super().visit_call_(op) # call PyExprMutator.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + return super().visit_var_(op) # call PyExprMutator.visit_var_ + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + return super().visit_op_(op) # call PyExprMutator.visit_op_ + + @relax.expr_functor.mutator + class LeafMutator(InternalMutator._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + return super().visit_call_(op) # call InternalMutator.visit_call_ + + call_node = relax.op.add(x, y) + im = InternalMutator() + im.visit_expr(call_node) + assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lm = LeafMutator() + lm.visit_expr(call_node) + assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py new file mode 100644 index 000000000000..39f9af103134 --- /dev/null +++ b/tests/python/relax/test_frontend_common.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm.relax.frontend import detach_params +from tvm.script.parser import relax as R + + +def test_detach_params(): + @R.function + def func(x: R.Tensor((2, 3), "float32")): + return x + + param = tvm.nd.empty((3,), "float32") + mod = tvm.IRModule({"func": func.with_attr("params", [param])}) + detached_mod, detached_params = detach_params(mod) + + tvm.ir.assert_structural_equal(detached_mod, tvm.IRModule({"func": func})) + assert len(detached_params) == 1 + assert "func" in detached_params + assert isinstance(detached_params["func"], list) + assert len(detached_params["func"]) == 1 + tvm.testing.assert_allclose(detached_params["func"][0].numpy(), param.numpy()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py new file mode 100644 index 000000000000..72ea193a029b --- /dev/null +++ b/tests/python/relax/test_frontend_dynamo.py @@ -0,0 +1,369 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
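+"""Tests for the Relax PyTorch Dynamo integration: the relax_dynamo torch.compile backend
+and dynamo_capture_subgraphs."""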
+import pytest + +pytest.importorskip("torch._dynamo") + + +import tvm +from tvm import relax, meta_schedule as ms, tir +import tvm.testing +import torch +import torch._dynamo as dynamo +from tvm.relax.frontend.torch import relax_dynamo +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_relax_dynamo(): + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + model = Input1() + ### construct the database + @tvm.script.ir_module + class Input1_ir: + @T.prim_func + def main( + inp_0: T.Buffer((T.int64(10), T.int64(100)), "float32"), + param_0: T.Buffer((T.int64(100), T.int64(10)), "float32"), + param_1: T.Buffer(T.int64(10), "float32"), + compute: T.Buffer((T.int64(10), T.int64(10)), "float32"), + ): + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1]) + T.writes(matmul[v_i0, v_i1]) + with T.init(): + matmul[v_i0, v_i1] = T.float32(0) + matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1] + for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1] + for i0, i1 in T.grid(T.int64(10), T.int64(10)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], T.float32(0)) + + db = ms.Database.create("memory") + workload = db.commit_workload(Input1_ir) + + sch = tir.Schedule(Input1_ir, debug_mask="all") + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b1) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1] + ) + l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], preserve_unit_iters=True) + v14, v15, v16, v17 = sch.sample_perfect_tile( + loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1] + ) + l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], preserve_unit_iters=True) + v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, decision=[100, 1]) + l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True) + sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21) + (b26,) = sch.get_consumers(block=b0) + sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=96) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=64) + v27 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + 
sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v27) + + tuning_record = ms.database.TuningRecord(sch.trace, workload, run_secs=[0.0]) + db.commit_tuning_record(tuning_record) + ### Optimize the model with tuned-log + with db: + opt_model = torch.compile(model, backend=relax_dynamo()) + inp = torch.randn(10, 100) + tvm.testing.assert_allclose( + opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 + ) + + +def test_subgraph_capture(): + import torch + from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs + + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + @tvm.script.ir_module + class Expected1: + @R.function + def subgraph_0( + inp_0: R.Tensor((10, 100), dtype="float32"), + w0: R.Tensor((10, 100), dtype="float32"), + w1: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) + lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32") + lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) + lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) + gv: R.Tensor((10, 10), dtype="float32") = lv3 + R.output(gv) + return gv + + model = Input1() + mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)) + binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("subgraph_0", binding)(Expected1) + tvm.ir.assert_structural_equal(mod, expected) + + def Input2(a, b): + x = a / (torch.sin(a) + 1) + if torch.sum(b) < 1: + b = b * -1 + return x * b + + @tvm.script.ir_module + class Expected2: + @R.function + def subgraph_0( + inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), dtype="float32") + ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0) + lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, "float32")) + lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1) + lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, keepdims=False) + lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, "float32")) + gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")) = ( + lv2, + lv4, + ) + R.output(gv) + return gv + + @R.function + def subgraph_1( + inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # block 0 + with R.dataflow(): + lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01) + gv1: R.Tensor((10,), dtype="float32") = lv5 + R.output(gv1) + return gv1 + + mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)) + tvm.ir.assert_structural_equal(mod, Expected2) + + class Input3(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x, add_one=False): + if add_one: + x = x + 1 + return torch.nn.functional.relu(self.lin(x)) + + @tvm.script.ir_module + class Expected3: + @R.function + def subgraph_0( + inp_0: R.Tensor((10, 100), dtype="float32"), + w0: R.Tensor((10, 100), dtype="float32"), + w1: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10, 10), 
dtype="float32"): + # block 0 + with R.dataflow(): + lv0 = R.add(inp_0, R.const(1, "float32")) + lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) + lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, out_dtype="float32") + lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) + lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) + gv: R.Tensor((10, 10), dtype="float32") = lv3 + R.output(gv) + return gv + + model = Input3() + mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True) + binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("subgraph_0", binding)(Expected3) + tvm.ir.assert_structural_equal(mod, expected) + + +def verify_dynamo_model(torch_model, input_info, binding, expected): + import torch + import torch._dynamo as dynamo + from tvm.relax.frontend.torch import from_fx + + args = [] + for info in input_info: + args.append(torch.zeros(*info[0], dtype=_convert_data_type(info[1]))) + graph_model = dynamo.export(torch_model, *args)[0] + mod = from_fx(graph_model, input_info, unwrap_unit_return_tuple=True) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +def _convert_data_type(input_type): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + input_type = input_type.lower() if isinstance(input_type, str) else input_type + if input_type == "float32": + return torch.float32 + elif input_type == "float16": + return torch.float16 + elif input_type == "int64": + return torch.int64 + elif input_type == "int32": + return torch.int32 + elif input_type == "bool": + return torch.bool + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + +@tvm.testing.requires_gpu +def test_ones(): + import torch + from torch.nn import Module + + class Ones(Module): + def forward(self, input): + return torch.ones((10, 10), dtype=torch.float32) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_dynamo_model( + Ones(), + [([256, 256], "float32")], + {}, + Expected1, + ) + + +@tvm.testing.requires_gpu +def test_full(): + import torch + from torch.nn import Module + + class Full(Module): + def forward(self, input): + return torch.full((10, 10), 1, dtype=torch.float32) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_dynamo_model( + Full(), + [([256, 256], "float32")], + {}, + Expected1, + ) + + +@tvm.testing.requires_gpu +def test_masked_fill(): + import torch + from torch.nn import Module + + class MaskedFill(Module): + def forward(self, mask, input): + return input.masked_fill(mask, 0) + + class InplaceMaskedFill(Module): + def forward(self, mask, input): + 
input.masked_fill_(mask, 0) + return input + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="bool"), inp_1: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.full_like( + inp_1, R.const(0, "int32"), dtype="void" + ) + lv1: R.Tensor((256, 256), dtype="float32") = R.where(inp_0, lv, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_dynamo_model( + MaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")], {}, Expected1 + ) + verify_dynamo_model( + InplaceMaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")], {}, Expected1 + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py new file mode 100644 index 000000000000..9e07ff7b59f7 --- /dev/null +++ b/tests/python/relax/test_frontend_from_fx.py @@ -0,0 +1,2747 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
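+"""Tests for the Relax torch.fx frontend (from_fx), covering operator conversions such as
+conv1d/conv2d, linear, matmul/bmm/baddbmm, and activations."""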
+import pytest + +import tvm +from tvm import relax +import tvm.testing +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def verify_model(torch_model, input_info, binding, expected): + import torch + from torch import fx + from tvm.relax.frontend.torch import from_fx + + graph_model = fx.symbolic_trace(torch_model) + mod = from_fx(graph_model, input_info) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +@tvm.testing.requires_gpu +def test_conv1d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10], "float32")] + + model = Conv1D1() + binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + +@tvm.testing.requires_gpu +def test_conv2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, 
[1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Conv2D1() + binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + +@tvm.testing.requires_gpu +def test_linear(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + # nn.Linear + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((7,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2 + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Dense1() + binding = {"w1": model.linear.weight.numpy(), "w2": model.linear.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + # matmul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with 
R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + MatMul1(), + [([10, 10], "float32"), ([10, 10], "float32")], + {}, + expected3, + ) + + +@tvm.testing.requires_gpu +def test_bmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class BMM(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((4, 128, 256), dtype="float32"), + input_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + BMM(), + [((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected, + ) + + +@tvm.testing.requires_gpu +def test_baddbmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class BAddBMM1(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + class BAddBMM2(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv1 + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model( + BAddBMM1(), + [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected1, + ) + + verify_model( + BAddBMM2(), + [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected2, + ) + + +@tvm.testing.requires_gpu +def test_relu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1) + gv: 
R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU0(), input_info, {}, expected) + verify_model(ReLU1(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_relu6(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU6(Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, input): + return self.relu6(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU6(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_maxpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d2(), input_info, {}, expected2) + verify_model(MaxPool2d3(), input_info, {}, expected3) + + +@tvm.testing.requires_gpu +def test_adaptive_avgpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class 
AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool2d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_flatten(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Flatten(Module): + def __init__(self): + super().__init__() + self.f = torch.nn.Flatten(2, -1) + + def forward(self, input): + return self.f(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 100), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) + gv: R.Tensor((1, 3, 100), dtype="float32") = lv + R.output(gv) + return gv + + # call_module + verify_model(Flatten(), input_info, {}, expected1) + # call_method + verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_batchnorm2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + model = BatchNorm2d() + binding = { + "w1": model.bn.weight.numpy(), + "w2": model.bn.bias.numpy(), + "w3": model.bn.running_mean.numpy(), + "w4": model.bn.running_var.numpy(), + } + verify_model(BatchNorm2d(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_embedding(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([4], "int64")] + + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + 
@R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tensor((4, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tensor((4, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + model = Embedding() + binding = {"w1": model.embedding.weight.numpy()} + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_dropout(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Dropout1(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + class Dropout2(Module): + def forward(self, input): + return torch.dropout(input, 0.5, train=True) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1 + R.output(gv) + return gv + + verify_model(Dropout1(), input_info, {}, expected1) + verify_model(Dropout2(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_layernorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = LayerNorm() + binding = { + "w1": model.ln.weight.numpy(), + "w2": model.ln.bias.numpy(), + } + verify_model(LayerNorm(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_functional_layernorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self, shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(shape)) + self.bias = torch.nn.Parameter(torch.zeros(shape)) + + def forward(self, input): + return torch.nn.functional.layer_norm( + input, self.weight.shape, self.weight, self.bias, 1e-5 + ) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model 
= LayerNorm((10, 10)) + binding = { + "w1": model.weight.numpy(), + "w2": model.bias.numpy(), + } + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_silu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + class SiLU2(Module): + def forward(self, input): + return torch.nn.functional.silu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(SiLU(), input_info, {}, expected1) + verify_model(SiLU2(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = GroupNorm() + binding = { + "w1": model.gn.weight.numpy(), + "w2": model.gn.bias.numpy(), + } + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_softmax(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Softmax(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_binary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] + input_info2 = [([1, 3, 10, 10], "float32")] + + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected1: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with 
R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected2: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Add1(), input_info1, {}, expected1) + verify_model(Add2(), input_info2, {}, expected2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected3: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected4: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sub1(), input_info1, {}, expected3) + verify_model(Sub2(), input_info2, {}, expected4) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected5: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected6: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mul1(), input_info1, {}, expected5) + verify_model(Mul2(), input_info2, {}, expected6) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected7: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected8: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: 
R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TrueDiv1(), input_info1, {}, expected7) + verify_model(TrueDiv2(), input_info2, {}, expected8) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected9: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class expected10: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(FloorDiv1(), input_info1, {}, expected9) + verify_model(FloorDiv2(), input_info2, {}, expected10) + + # Power + class Power1(Module): + def forward(self, lhs, rhs): + return lhs**rhs + + @tvm.script.ir_module + class expected11: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Power2(Module): + def forward(self, lhs): + return lhs**1.0 + + @tvm.script.ir_module + class expected12: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Power1(), input_info1, {}, expected11) + verify_model(Power2(), input_info2, {}, expected12) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected13: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected14: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + verify_model(LT1(), input_info1, {}, expected13) + verify_model(LT2(), input_info2, {}, expected14) + + +@tvm.testing.requires_gpu +def test_size(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + 
input_info = [([1, 3, 10, 10], "float32")] + + class Size(Module): + def forward(self, input): + return input.size() + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return gv + + verify_model(Size(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_squeeze(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([3, 1, 4, 1], "float32")] + + class Squeeze1(Module): + def forward(self, input): + return input.squeeze(1) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tensor((3, 4, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) + gv: R.Tensor((3, 4, 1), dtype="float32") = lv + R.output(gv) + return gv + + class Squeeze2(Module): + def forward(self, input): + return input.squeeze() + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tensor((3, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) + gv: R.Tensor((3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Squeeze1(), input_info, {}, Expected1) + verify_model(Squeeze2(), input_info, {}, Expected2) + + +@tvm.testing.requires_gpu +def test_unsqueeze(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Unsqueeze1(), input_info, {}, expected1) + verify_model(Unsqueeze2(), input_info, {}, expected2) + + +@tvm.testing.requires_gpu +def test_getattr(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GetAttr1(Module): + def forward(self, input): + return input.shape + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return gv + + verify_model(GetAttr1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_getitem(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + 
torch.random.manual_seed(0) + + class Slice1(Module): + def forward(self, x): + return x[0, 1::2, :, :3] + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 10, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice( + x, + axes=[0, 1, 2, 3], + begin=[0, 1, 0, 0], + end=[1, T.int64(3), T.int64(10), 3], + strides=[1, 2, 1, 1], + ) + lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, (1, 1, 10, 3)) + gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + @I.ir_module + class expected2: + @R.function + def main( + inp_0: R.Tensor((8, 16), dtype="float32") + ) -> R.Tensor((8, 1, 1, 16, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( + inp_0, axes=[0, 1], begin=[0, 0], end=[8, 16], strides=[1, 1] + ) + lv1: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.reshape( + lv, R.shape([8, 1, 1, 16, 1]) + ) + gv: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Slice1(), [([1, 3, 10, 10], "float32")], {}, expected1) + verify_model(Slice2(), [([8, 16], "float32")], {}, expected2) + + +@tvm.testing.requires_gpu +def test_unary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + # sin + class Sin(Module): + def forward(self, input): + return torch.sin(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sin(), input_info, {}, expected1) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cos(), input_info, {}, expected2) + + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Exp(), input_info, {}, expected_exp) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sqrt(), input_info, {}, expected3) + + # sigmoid + class 
Sigmoid(Module): + def forward(self, input): + return torch.sigmoid(input) + + @tvm.script.ir_module + class expected4: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sigmoid(), input_info, {}, expected4) + + # round + class Round(Module): + def forward(self, input): + return torch.round(input) + + @tvm.script.ir_module + class expected5: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Round(), input_info, {}, expected5) + + +@tvm.testing.requires_gpu +def test_gelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Gelu(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Gelu(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_tanh(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tanh(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_clamp(): + import torch + from torch import fx + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Clamp(), input_info, {}, expected1) + + from tvm.relax.frontend.torch import from_fx + + with pytest.raises( + ValueError, match="TVM only supports constant max value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + with pytest.raises( + ValueError, match="TVM only supports constant min value for 
torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + +@tvm.testing.requires_gpu +def test_interpolate(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Interpolate(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (5, 5)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_addmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [ + ([10, 10], "float32"), + ([10, 10], "float32"), + ([10, 10], "float32"), + ] + + class Addmm(Module): + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tensor((10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Addmm(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_split(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Split(Module): + def forward(self, input): + return torch.split(input, 1, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Split(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_cumsum(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Cumsum(Module): + def forward(self, input): + return torch.cumsum(input, dim=1, dtype=torch.int32) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: 
R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32") + gv: R.Tensor((1, 2, 3, 4), dtype="int32") = lv + R.output(gv) + return gv + + verify_model(Cumsum(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_chunk(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Chunk(Module): + def forward(self, input): + return torch.chunk(input, 3, dim=1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Chunk(), input_info, {}, Expected) + + +@tvm.testing.requires_gpu +def test_inplace_fill(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class InplaceFill(Module): + def forward(self, input): + input.fill_(1.5) + return input + + @tvm.script.ir_module + class Expected: + @R.function + def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected) + + +@tvm.testing.requires_gpu +def test_arange(): + import numpy as np + import torch + from torch import fx + from torch.nn import Module + from tvm.relax.frontend.torch import from_fx + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Arange(Module): + def forward(self, input): + return torch.arange(0, 20, dtype=torch.int32) + + graph_model = fx.symbolic_trace(Arange()) + mod = from_fx(graph_model, [([10, 10], "float32")]) + assert len(mod["main"].body.blocks) == 1 + assert len(mod["main"].body.blocks[0].bindings) == 1 + assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant) + tvm.testing.assert_allclose( + mod["main"].body.blocks[0].bindings[0].value.data.numpy(), np.arange(0, 20, dtype="int32") + ) + + +@tvm.testing.requires_gpu +def test_empty(): + import torch + from torch import fx + from torch.nn import Module + from tvm.relax.frontend.torch import from_fx + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Empty(Module): + def forward(self, input): + return torch.empty((10, 10), dtype=torch.float32) + + graph_model = fx.symbolic_trace(Empty()) + mod = from_fx(graph_model, [([10, 10], "float32")]) + assert len(mod["main"].body.blocks) == 1 + assert len(mod["main"].body.blocks[0].bindings) == 1 + assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant) + assert mod["main"].body.blocks[0].bindings[0].value.data.shape == (10, 10) + assert 
mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32" + + +@tvm.testing.requires_gpu +def test_tensor(): + import torch + from torch import fx + from torch.nn import Module + from tvm.relax.frontend.torch import from_fx + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Empty1(Module): + def forward(self, input): + return torch.tensor(3, dtype=torch.float32) + + class Empty2(Module): + def forward(self, input): + return torch.tensor(3) + + graph_model1 = fx.symbolic_trace(Empty1()) + mod1 = from_fx(graph_model1, [([10, 10], "float32")]) + assert len(mod1["main"].body.blocks) == 1 + assert len(mod1["main"].body.blocks[0].bindings) == 1 + assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant) + assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == () + assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32" + + graph_model2 = fx.symbolic_trace(Empty2()) + mod2 = from_fx(graph_model2, [([10, 10], "float32")]) + assert len(mod2["main"].body.blocks) == 1 + assert len(mod2["main"].body.blocks[0].bindings) == 1 + assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant) + assert mod2["main"].body.blocks[0].bindings[0].value.data.shape == () + assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64" + + +@tvm.testing.requires_gpu +def test_tril(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([10, 10], "float32")] + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + class InplaceTril(Module): + def forward(self, input): + input.tril_(1) + return input + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tril(), input_info, {}, expected1) + verify_model(InplaceTril(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_triu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([10, 10], "float32")] + + class Triu(Module): + def forward(self, input): + return torch.triu(input, 1) + + class InplaceTriu(Module): + def forward(self, input): + input.triu_(1) + return input + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Triu(), input_info, {}, expected1) + verify_model(InplaceTriu(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_new_ones(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3], "float32")] + + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) 
+ gv: R.Tensor((1, 2, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewOnes(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_expand(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Expand(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((4, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Expand(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reduce(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # sum + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tensor((1, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sum(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_datatype(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ToFloat(), input_info, {}, expected1) + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float16"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv + R.output(gv) + return gv + + verify_model(ToHalf(), input_info, {}, expected2) + + # type + class Type(Module): + def forward(self, x): + return x.type(torch.float32) + + # type + class TypeFromAttr(Module): + def forward(self, x): + return x.type(x.getattr("dtype")) + + # astype + class AsType(Module): + def forward(self, x): + return x.astype(torch.float32) + + verify_model(Type(), input_info, {}, expected1) + verify_model(TypeFromAttr(), input_info, {}, expected1) + verify_model(AsType(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_permute(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Permute(Module): + def forward(self, x): + return x.permute(0, 3, 2, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> 
R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Permute(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reshape(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Reshape(Module): + def forward(self, x): + return x.reshape(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Reshape(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_transpose(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Transpose(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_view(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(View(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_keep_params(): + import torch + from torch import fx + from torch.nn import Module + from tvm.relax.frontend import detach_params + from tvm.relax.frontend.torch import from_fx + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6,), dtype="float32"), + w2: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w1, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + model = Conv2D1() + graph_model = fx.symbolic_trace(model) + mod = 
from_fx(graph_model, [([1, 3, 10, 10], "float32")], keep_params_as_input=True) + mod, params = detach_params(mod) + tvm.ir.assert_structural_equal(mod, expected1) + func = mod["main"] + params = params["main"] + + assert len(params) == len(func.params) - 1 + for param_var, param_ndarray in zip(func.params[1:], params): + assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape + assert param_var.struct_info.dtype == param_ndarray.dtype + + tvm.testing.assert_allclose(params[0].numpy(), model.conv.bias.detach().numpy()) + tvm.testing.assert_allclose(params[1].numpy(), model.conv.weight.detach().numpy()) + + +@tvm.testing.requires_gpu +def test_unwrap_unit_return_tuple(): + import torch.fx as fx + from torch.nn import Module + from tvm.relax.frontend.torch import from_fx + + class Identity(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return (x,) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + R.output(gv) + return gv + + graph_model = fx.symbolic_trace(Identity()) + mod = from_fx(graph_model, [([256, 256], "float32")], unwrap_unit_return_tuple=True) + tvm.ir.assert_structural_equal(mod, Expected) + + +@tvm.testing.requires_gpu +def test_argmax(): + import torch + from torch.nn import Module + + class Argmax1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1) + + class Argmax2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1, keepdim=True) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256,), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) + gv: R.Tensor((256,), dtype="int64") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256, 1), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) + gv: R.Tensor((256, 1), dtype="int64") = lv + R.output(gv) + return gv + + verify_model(Argmax1(), [([256, 256], "float32")], {}, Expected1) + verify_model(Argmax2(), [([256, 256], "float32")], {}, Expected2) + + +@tvm.testing.requires_gpu +def test_argmin(): + import torch + from torch.nn import Module + + class Argmin1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input) + + class Argmin2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input, keepdim=True) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) + gv: R.Tensor((), dtype="int64") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) + gv: R.Tensor((1, 1), dtype="int64") = 
lv + R.output(gv) + return gv + + verify_model(Argmin1(), [([256, 256], "float32")], {}, Expected1) + verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2) + + +@tvm.testing.requires_gpu +def test_to(): + import torch + from torch.nn import Module + + class To1(Module): + def forward(self, input): + return input.to(torch.float16) + + class To2(Module): + def forward(self, input): + return input.to("cpu") + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float16"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float16") = R.astype(inp_0, dtype="float16") + gv: R.Tensor((256, 256), dtype="float16") = lv + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + R.output(gv) + return gv + + verify_model(To1(), [([256, 256], "float32")], {}, Expected1) + verify_model(To2(), [([256, 256], "float32")], {}, Expected2) + + +@tvm.testing.requires_gpu +def test_mean(): + import torch + from torch.nn import Module + + class Mean(Module): + def forward(self, input): + return input.mean(-1) + + class MeanKeepDim(Module): + def forward(self, input): + return input.mean(-1, keepdim=True) + + @I.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) + gv: R.Tensor((256,), dtype="float32") = lv + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) + gv: R.Tensor((256, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mean(), [([256, 256], "float32")], {}, Expected1) + verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2) + + +@tvm.testing.requires_gpu +def test_rsqrt(): + import torch + from torch.nn import Module + + class Rsqrt(Module): + def forward(self, input): + return torch.rsqrt(input) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.sqrt(inp_0) + lv1: R.Tensor((256, 256), dtype="float32") = R.divide(R.const(1, "float32"), lv) + gv: R.Tensor((256, 256), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1) + + +@tvm.testing.requires_gpu +def test_neg(): + import torch + from torch.nn import Module + + class Neg(Module): + def forward(self, input): + return -input + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.negative(inp_0) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Neg(), [([256, 256], "float32")], {}, Expected1) + + +@tvm.testing.requires_gpu +def test_max(): + import torch + from torch.nn import Module + + class Max(Module): + def forward(self, x, y): + return torch.max(x, y) + + 
@I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.maximum(inp_0, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py new file mode 100644 index 000000000000..809fe7e98f7c --- /dev/null +++ b/tests/python/relax/test_op_binary.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.add(x, y).op == Op.get("relax.add") + assert relax.op.divide(x, y).op == Op.get("relax.divide") + assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide") + assert relax.op.multiply(x, y).op == Op.get("relax.multiply") + assert relax.op.power(x, y).op == Op.get("relax.power") + assert relax.op.subtract(x, y).op == Op.get("relax.subtract") + + assert relax.op.equal(x, y).op == Op.get("relax.equal") + assert relax.op.greater(x, y).op == Op.get("relax.greater") + assert relax.op.greater_equal(x, y).op == Op.get("relax.greater_equal") + assert relax.op.less(x, y).op == Op.get("relax.less") + assert relax.op.less_equal(x, y).op == Op.get("relax.less_equal") + assert relax.op.not_equal(x, y).op == Op.get("relax.not_equal") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.power,), + (relax.op.subtract,), + (relax.op.maximum,), + (relax.op.minimum,), +) + + +def test_binary_arith_infer_struct_info(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((1, 3), "float32")) + x2 = relax.Var("x", R.Tensor((3, 2, 3), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, 3), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 2, 1), "float32")) + y2 = 
relax.Var("y", R.Tensor("float32", ndim=2)) + y3 = relax.Var("y", R.Tensor("float32", ndim=-1)) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((4, 3, 2, 3), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y1), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1)) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3), "int32")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + + +def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((1, n), "float32")) + x2 = relax.Var("x", R.Tensor((k, n, m), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, n), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor((m, n + 2), "float32")) + y2 = relax.Var("y", R.Tensor((4, k, m, 1), "float32")) + y3 = relax.Var("y", R.Tensor("float32", ndim=2)) + y4 = relax.Var("y", R.Tensor("float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x1, y2), relax.TensorStructInfo((4, k, m, n), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x2, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y2), 
relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y4), relax.TensorStructInfo(dtype="float32", ndim=-1)) + + +def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s4", relax.ShapeStructInfo()) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + + +def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y0)) + + +def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x, y)) + + +def test_binary_wrong_input_number(binary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + binary_arith_op(x, x, x) + with pytest.raises(TypeError): + binary_arith_op(x) + with pytest.raises(TypeError): + binary_arith_op(x, x, x, x) + + +def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y)) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x1, y)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_create.py 
b/tests/python/relax/test_op_create.py new file mode 100644 index 000000000000..6dd0a0d15ead --- /dev/null +++ b/tests/python/relax/test_op_create.py @@ -0,0 +1,638 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + fill_value = relax.Var("fill_value", R.Tensor((), "float32")) + assert relax.op.full((2, 3), fill_value).op == Op.get("relax.full") + assert relax.op.full_like(x, fill_value).op == Op.get("relax.full_like") + assert relax.op.ones((2, 3), "float32").op == Op.get("relax.ones") + assert relax.op.ones_like(x).op == Op.get("relax.ones_like") + assert relax.op.zeros((2, 3), "float32").op == Op.get("relax.zeros") + assert relax.op.zeros_like(x).op == Op.get("relax.zeros_like") + assert relax.op.tril(x).op == Op.get("relax.tril") + assert relax.op.triu(x).op == Op.get("relax.triu") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_full_infer_struct_info(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v0), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), 
"float32")) + _check_inference( + bb, relax.op.full(s0, v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v1), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v2), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v2), relax.TensorStructInfo(s3, dtype="")) + _check_inference( + bb, relax.op.full((2, 3), v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference( + bb, + relax.op.full( + s1, + v3, + ), + relax.TensorStructInfo(s1, dtype=""), + ) + _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference( + bb, + relax.op.full( + s2, + v3, + ), + relax.TensorStructInfo(s2, dtype=""), + ) + _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v3), relax.TensorStructInfo(s3, dtype="")) + + +def test_full_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + s0 = relax.ShapeExpr((a, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) + + _check_inference( + bb, relax.op.full((a, 3), v, "float16"), relax.TensorStructInfo((a, 3), "float16") + ) + _check_inference(bb, relax.op.full((a, 3), v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorStructInfo((a, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v), relax.TensorStructInfo(s1, "float32")) + + +def test_full_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 
= relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v1 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor((), "int8")) + v2 = relax.Var("v", R.Tensor((), "int32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.full((2, 3), v1, "int32"), relax.TensorStructInfo((2, 3), "int32") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), "int32")) + + +def test_full_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v5)) + + +def test_full_shape_not_tuple(): + m = tir.Var("m", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + relax.op.full(4, v) + with pytest.raises(TVMError): + relax.op.full(m, v) + + +def test_full_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + v2 = relax.Var("v", relax.FuncStructInfo([], R.Tensor((), "float32"))) + s = relax.Var("s", R.Tensor((2, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full(s, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + + +def test_full_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor("float16", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + + _check_inference(bb, 
relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v3), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + v = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.full_like(x1, v), relax.TensorStructInfo((m, n), dtype="")) + + +def test_full_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", 
relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", R.Tensor((2, 3), "float32")) + sv0 = relax.Var("sv", relax.ShapeStructInfo(())) + sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16")) + v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16")) + v2 = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32")) + + +def test_full_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + v0 = relax.Var("v", R.Tensor((), "int32")) + v1 = relax.Var("v", R.Tensor((), "float64")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v5)) + + +def test_full_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3))) + v0 = relax.Var("v", R.Tensor(())) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x0, v0)) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.full_like(x1, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x2, v1)) + + +def test_ones_zeros_infer_struct_info(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.ones((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.zeros((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + + +def test_ones_zeros_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.ShapeExpr((m, n)) + s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) + + _check_inference( + bb, relax.op.ones((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference( + bb, relax.op.zeros((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + + +def test_ones_zeros_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorStructInfo(s1, "int8")) + _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorStructInfo(s2, "int32")) + _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorStructInfo(s3, "float64")) + + +def test_ones_zeros_shape_not_tuple(): + m = tir.Var("m", "int64") + + with pytest.raises(TVMError): + relax.op.ones(10, "float32") + with pytest.raises(TVMError): + relax.op.zeros(m, "float32") + + +def test_ones_zeros_wrong_dtype(): + with pytest.raises(TypeError): + relax.op.ones((2, 3)) + with pytest.raises(TVMError): + relax.op.ones((2, 3), "") + with pytest.raises(TypeError): + relax.op.zeros((2, 3)) + with pytest.raises(TVMError): + relax.op.zeros((2, 3), "") + + +def test_ones_zeros_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", R.Tensor((2, 3))) + s1 = relax.Var("s", relax.FuncStructInfo([], R.Tensor((2, 3)))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones(s0, "float32")) + 
with pytest.raises(TVMError): + bb.normalize(relax.op.zeros(s1, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ones_like(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.zeros_like(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.ones_like(x4), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.zeros_like(x5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((m, n), dtype="")) + + +def test_ones_like_zeros_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros_like(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones_like(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.zeros_like(x1)) + + +def test_tril_triu_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x0, k=0), 
relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype="")) + + +def test_tril_triu_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a, b, c), "float32")) + x1 = relax.Var("x", R.Tensor((a, b, c))) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype="")) + + +def test_tril_triu_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_tril_triu_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference(bb, relax.op.triu(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.tril(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.triu(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + + +def test_tril_triu_infer_struct_info_less_than_two_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2,))) + s1 = relax.Var("s", relax.ShapeStructInfo(())) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((2,), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=1)) + x3 = relax.Var("x", R.Tensor("float32", ndim=0)) + x4 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x5)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x7)) + + +def test_tril_triu_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = 
relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py new file mode 100644 index 000000000000..48820b9e2e00 --- /dev/null +++ b/tests/python/relax/test_op_datatype.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np # type: ignore + + +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + c = relax.Constant(tvm.nd.array(np.array([1, 2, 3], dtype="float16"))) + assert relax.op.astype(x, "float16").op == Op.get("relax.astype") + assert relax.op.wrap_param(c, "float32").op == Op.get("relax.wrap_param") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_astype_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(dtype="float16")) + _check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x4, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorStructInfo(dtype="float16")) + + +def test_astype_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((m, n), "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo((m, n), "float16")) + + +def test_astype_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", 
relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo(s0, "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(s2, "float16")) + + +def test_astype_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + + _check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorStructInfo((2, 3), "int32")) + _check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + + +def test_astype_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x0, "float16")) + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x1, "float16")) + + +def test_wrap_param_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="float16"))) + x1 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="int8"))) + _check_inference( + bb, relax.op.wrap_param(x0, "float32"), relax.TensorStructInfo((1, 2, 3), "float32") + ) + _check_inference( + bb, relax.op.wrap_param(x1, "int32"), relax.TensorStructInfo((1, 2, 3), "int32") + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py new file mode 100644 index 000000000000..b06b51a2a198 --- /dev/null +++ b/tests/python/relax/test_op_image.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + assert relax.op.image.resize2d(x, (28, 28)).op == Op.get("relax.image.resize2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_resize2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=4)) + x4 = relax.Var("x", R.Tensor("float32", ndim=5)) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor(ndim=4)) + x7 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28), + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=28, layout="NHWC"), + relax.TensorStructInfo((2, 28, 28, 3), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28, out_dtype="float16"), + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=28, layout="NCHW16c"), + relax.TensorStructInfo((2, 4, 28, 28, 16), "float32"), + ) + _check_inference( + bb, relax.op.image.resize2d(x3, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x4, size=28, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, relax.op.image.resize2d(x5, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.image.resize2d(x6, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x6, size=28, out_dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x7, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_resize2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + oh = tir.Var("oh", "int64") + ow = tir.Var("ow", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=oh), relax.TensorStructInfo((n, c, oh, oh), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(oh, ow)), + relax.TensorStructInfo((n, c, oh, ow), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=(oh, ow), layout="NCHW16c"), + relax.TensorStructInfo((n, c, oh, ow, 16), "float32"), + ) + + +def test_resize2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + 
x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=32), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_resize2d_infer_struct_info_pool_size_var(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((30, 30))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + + _check_inference( + bb, + relax.op.image.resize2d(x0, s0), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x0, s1), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + + +def test_resize2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.image.resize2d(x0, size=28), relax.TensorStructInfo((2, 3, 28, 28), "float16") + ) + _check_inference( + bb, relax.op.image.resize2d(x1, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int8") + ) + _check_inference( + bb, relax.op.image.resize2d(x2, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int64") + ) + + +def test_resize2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x, size=28, layout="OIHW")) + + +def test_resize2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=28, layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=28, layout="NCHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, size=28)) + + +def test_resize2d_wrong_pool_size_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + s0 = relax.ShapeExpr((3,)) + s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s5 = relax.Var("s", relax.ShapeStructInfo()) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, (3, 3, 3))) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s5)) + + +def test_resize2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = 
relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", R.Tensor((3, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, s0)) + with pytest.raises(TVMError): + relax.op.image.resize2d(x2, [30, 30]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py new file mode 100644 index 000000000000..a84e70a0eb2e --- /dev/null +++ b/tests/python/relax/test_op_index.py @@ -0,0 +1,626 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + idx = relax.Var("idx", R.Tensor((2,), "float32")) + assert relax.op.take(x, idx, axis=1).op == Op.get("relax.take") + assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get( + "relax.strided_slice" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_take_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((4, 10))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((10,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + y2 = relax.Var("y", R.Tensor((10,))) + y3 = relax.Var("y", R.Tensor(ndim=1)) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + idx2 = relax.Var("idx", R.Tensor((6,))) + idx3 = relax.Var("idx", R.Tensor(ndim=1)) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32") + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx0, axis=1), 
relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx0), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx2), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx2), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_take_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + i 
= tir.Var("i", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + y0 = relax.Var("y", R.Tensor((n,), "float32")) + y1 = relax.Var("y", R.Tensor((n,))) + idx0 = relax.Var("idx", R.Tensor((i,), "int64")) + idx1 = relax.Var( + "idx", + R.Tensor( + (i,), + ), + ) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype="")) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype="")) + + +def test_take_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx2 = relax.Var("idx", R.Tensor((6,), "int64")) + + _check_inference( + bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_take_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float16")) + x1 = relax.Var("x", R.Tensor((4, 10), "int16")) + x2 = relax.Var("x", R.Tensor((4, 10), "int32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int32")) + idx1 = relax.Var("idx", R.Tensor((6,), "int8")) + idx2 = relax.Var("idx", R.Tensor((6,), "uint32")) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx0, 
axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32")) + + +def test_take_infer_struct_info_indices_not_one_dimensional(): + bb = relax.BlockBuilder() + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6, 6))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(())) + sidx2 = relax.Var("sidx", relax.ShapeStructInfo(ndim=2)) + sidx3 = relax.Var("sidx", relax.ShapeStructInfo(ndim=0)) + sidx4 = relax.Var("sidx", relax.ShapeStructInfo()) + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "int64")) + idx1 = relax.Var("idx", R.Tensor((), "int64")) + idx2 = relax.Var("idx", R.Tensor("int64", ndim=2)) + idx3 = relax.Var("idx", R.Tensor("int64", ndim=0)) + idx4 = relax.Var("idx", R.Tensor("int64")) + idx5 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx6 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx7 = relax.Var("idx", relax.TensorStructInfo(sidx2, "int64")) + idx8 = relax.Var("idx", relax.TensorStructInfo(sidx3, "int64")) + idx9 = relax.Var("idx", relax.TensorStructInfo(sidx4, "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx2, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx3, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx4, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx5, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx6, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx7, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx8, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx9, axis=1)) + + +def test_take_infer_struct_info_indices_not_integer_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "float32")) + idx1 = relax.Var("idx", R.Tensor((6, 6), "float64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + + +def test_take_infer_struct_info_multi_dimensional_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx0)) + with pytest.raises(TVMError): 
+ bb.normalize(relax.op.take(x0, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx1)) + + +def test_take_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=-3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=2)) + + +def test_take_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((4, 10))) + x1 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.ShapeStructInfo((6,))) + idx1 = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0, axis=1)) + + +def test_strided_slice_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((8, 9, 10, 10))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x1, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x2, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x3, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x4, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x5, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]), + relax.TensorStructInfo((8, 7, 9, 10), "float32"), + ) + + +def test_strided_slice_infer_struct_info_shape_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((20, 10, 5), "float32")) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 2], begin=[20, 10, 4], end=[0, 0, 1], strides=[-1, -3, -2] + ), + relax.TensorStructInfo((19, 3, 2), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 2], begin=[200, 10, 4], end=[0, 0, 1], strides=[-1, -3, -2] + ), + relax.TensorStructInfo((19, 3, 2), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 2], begin=[200, 10, 100], end=[0, 0, 1], strides=[-1, -3, -5] + ), + relax.TensorStructInfo((19, 3, 1), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 2], 
begin=[-21, -11, -6], end=[1, 1, 1], strides=[1000, 1000, 1000] + ), + relax.TensorStructInfo((1, 1, 1), "float32"), + ) + + +def test_strided_slice_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, n), dtype=""), + ) + + +def test_strided_slice_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((8, 10))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, dtype="")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, dtype="")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, dtype="")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype=""), + ) + + +def test_strided_slice_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9), "float16")) + x1 = relax.Var("x", R.Tensor((8, 9), "int32")) + x2 = relax.Var("x", R.Tensor((8, 9), "int64")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "float16"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int64"), + ) + + +def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]), 
+ relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + + +def test_strided_slice_infer_struct_info_no_axis(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((m, n))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor(dtype="float32", ndim=2)) + x2 = relax.Var("x", R.Tensor(dtype="float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[], begin=[], end=[]), + relax.TensorStructInfo((m, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s2, "float32"), + ) + + +def test_strided_slice_begin_end_strides_int64(): + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + strided_slice = relax.op.strided_slice( + x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + + assert strided_slice.attrs.begin[0].dtype == "int64" + assert strided_slice.attrs.begin[1].dtype == "int64" + assert strided_slice.attrs.begin[2].dtype == "int64" + assert strided_slice.attrs.end[0].dtype == "int64" + assert strided_slice.attrs.end[1].dtype == "int64" + assert strided_slice.attrs.end[2].dtype == "int64" + assert strided_slice.attrs.strides[0].dtype == "int64" + assert strided_slice.attrs.strides[1].dtype == "int64" + assert strided_slice.attrs.strides[2].dtype == "int64" + + +def test_strided_slice_inconsistent_axes_begin_end_strides_length(): + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[], end=[9]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[9], strides=[]) + + +def test_strided_slice_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, 0], begin=[0, 0], end=[8, 8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, -2], begin=[0, 0], end=[8, 8])) + + +def test_strided_slice_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[2], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[-3], begin=[0], end=[8])) + + +def 
test_strided_slice_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((8, 9))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py new file mode 100644 index 000000000000..5eb19cf2b420 --- /dev/null +++ b/tests/python/relax/test_op_linear_algebra.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((3, 4), "float32")) + assert relax.op.matmul(x, y).op == Op.get("relax.matmul") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_matmul_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((4,), "float32")) + x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float32")) + x4 = relax.Var("x", R.Tensor((2, 1, 4, 5))) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((4,), "float32")) + y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32")) + y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32")) + y4 = relax.Var("y", R.Tensor("float32", ndim=5)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "float32") + ) + _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "")) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.matmul(x5, y3), relax.TensorStructInfo(dtype="float32")) + 
_check_inference(bb, relax.op.matmul(x3, y5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.matmul(x3, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + _check_inference( + bb, + relax.op.matmul(x6, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + + +def test_matmul_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k0 = tir.Var("k0", "int64") + k1 = tir.Var("k1", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + b1 = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((m, k0), "float32")) + x1 = relax.Var("x", R.Tensor((k0,), "float32")) + x2 = relax.Var("x", R.Tensor((a, b, m, k0), "float32")) + x3 = relax.Var("x", R.Tensor((b, 1, m, k0), "float32")) + x4 = relax.Var("x", R.Tensor((b, 1, m, k1), "float32")) + y0 = relax.Var("y", R.Tensor((k0, n), "float32")) + y1 = relax.Var("y", R.Tensor((k0,), "float32")) + y2 = relax.Var("y", R.Tensor((a, b, k0, n), "float32")) + y3 = relax.Var("y", R.Tensor((a, 1, c, k0, n), "float32")) + y4 = relax.Var("y", R.Tensor((a, b1, c, k0, n), "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((a, b, n), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((a, b, m), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference( + bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + + +def test_matmul_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s4", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=0)) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo(dtype="float32", ndim=0)) + + +def test_matmul_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) 
+ x2 = relax.Var("x", R.Tensor((3, 4), "int64")) + y2 = relax.Var("y", R.Tensor((4, 5), "int64")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float16")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((3, 5), "int8")) + _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorStructInfo((3, 5), "int64")) + + +def test_matmul_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4))) + y2 = relax.Var("y", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.matmul(x0, y0, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + _check_inference( + bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorStructInfo((3, 5), "int32") + ) + _check_inference( + bb, + relax.op.matmul(x2, y2, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + + +def test_matmul_infer_struct_info_zero_rank_input(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y0)) + + +def test_matmul_infer_struct_info_not_broadcastable(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 8, 3, 5, 6), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x, y)) + + +def test_matmul_infer_struct_info_unequal_reduction_length(): + bb = relax.BlockBuilder() + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, k), "float32")) + y0 = relax.Var("y", R.Tensor((6, 5), "float32")) + y1 = relax.Var("y", R.Tensor((k + 1, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y1)) + + +def test_linear(): + # Since linear is only a sugar for transpose + matmul + add, + # we only have brief tests here. 
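+    # A minimal sanity sketch of the assumed sugar (hypothetical, not asserted by this test):
+    #   relax.op.linear(x1, w1, b1)
+    #   ~= relax.op.add(relax.op.matmul(x1, relax.op.permute_dims(w1)), b1)
+    # so the struct info checked below should follow from the permute_dims/matmul/add
+    # inference rules exercised elsewhere in these tests.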
+ bb = relax.BlockBuilder() + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x2 = relax.Var("x", R.Tensor("float32")) + w1 = relax.Var("w", R.Tensor((5, 4), "float32")) + w2 = relax.Var("w", R.Tensor((4,), "float32")) + w3 = relax.Var("w", R.Tensor("float32")) + b1 = relax.Var("b", R.Tensor((5,), "float32")) + b2 = relax.Var("b", R.Tensor((), "float32")) + + # Need a scope to normalize non-leaf nodes + with bb.function("func", [x1]): + _check_inference( + bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32") + ) + _check_inference( + bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32") + ) + with pytest.raises(TVMError): + bb.normalize(relax.op.linear(x1, w2, b1)) # error on Add with shape (2, 3, 5) and (4,) + _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorStructInfo(dtype="float32")) + + # Fake output + gv = bb.emit_func_output(relax.Tuple([])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py new file mode 100644 index 000000000000..3edf63764a58 --- /dev/null +++ b/tests/python/relax/test_op_manipulate.py @@ -0,0 +1,3026 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
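+# Struct info inference tests for the Relax tensor-manipulation operators exercised below
+# (broadcast_to, concat, expand_dims, flatten, permute_dims, reshape, split, tile, repeat,
+# squeeze, layout_transform, collapse_sum_to, collapse_sum_like, cumsum).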
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R +from tvm.tir.expr import FloatImm, IntImm + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") + assert relax.op.concat([x]).op == Op.get("relax.concat") + assert relax.op.expand_dims(x, axis=[]).op == Op.get("relax.expand_dims") + assert relax.op.flatten(x).op == Op.get("relax.flatten") + assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims") + assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") + assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") + assert relax.op.tile(x, (2, 2, 2)).op == Op.get("relax.tile") + assert relax.op.repeat(x, 2, 0).op == Op.get("relax.repeat") + assert relax.op.squeeze(x).op == Op.get("relax.squeeze") + assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get( + "relax.layout_transform" + ) + assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to") + y = relax.Var("x", R.Tensor((4, 5), "float32")) + assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like") + assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_reshape_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + s0 = relax.Var("s", R.Shape((3, 8, 5))) + s1 = relax.Var("s", R.Shape(ndim=3)) + s2 = relax.Var("s", R.Shape()) + s3 = relax.ShapeExpr((3, 8, 5)) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) + 
_check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype="")) + + +def test_reshape_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + s0 = relax.Var("s", R.Shape((c, a, d, b))) + s1 = relax.Var("s", R.Shape()) + s2 = relax.ShapeExpr((c, a, d, b)) + + _check_inference( + bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") + ) + _check_inference( + bb, + relax.op.reshape(x, (d, c, b, -1)), + relax.TensorStructInfo((d, c, b, tir.floordiv(a * b * c * d, d * c * b)), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (1, -1, 1)), + relax.TensorStructInfo((1, a * b * c * d, 1), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (2, -1, a)), + relax.TensorStructInfo((2, tir.floordiv(a * b * c * d, a * 2), a), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, -1, d, b)), + relax.TensorStructInfo((c, tir.floordiv(a * b * c * d, c * d * b), d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * d, b)), + relax.TensorStructInfo((c, a * d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * b * d, -1)), + relax.TensorStructInfo( + (c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), "float32" + ), + ) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32")) + _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32")) + + +def test_reshape_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = 
relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + ns0 = relax.Var("ns", relax.ShapeStructInfo((3, 8, 5))) + ns1 = relax.Var("ns", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorStructInfo(ns1, "float32")) + + +def test_reshape_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorStructInfo((120,), "float16")) + _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorStructInfo((120,), "int8")) + + +def test_reshape_infer_struct_info_unequal_shape_prod(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + ns = relax.Var("ns", relax.ShapeStructInfo((4, 4, 1, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, ns)) + + +def test_reshape_infer_struct_info_inference_not_deducible(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32", ndim=4)) + x1 = relax.Var("x", R.Tensor("float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, (2, 3, -1))) + with 
pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x3, (2, 3, -1))) + + +def test_reshape_new_shape_not_tuple(): + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + relax.op.reshape(x, 120) + with pytest.raises(TVMError): + relax.op.reshape(x, m) + + +def test_reshape_infer_struct_info_new_shape_not_integer(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2.0, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, -1.0))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, 4.0, -1))) + + +def test_reshape_infer_struct_info_multiple_dim_inference(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, -1, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-1, -1, -1, -1))) + + +def test_reshape_infer_struct_info_non_positive_new_shape(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5))) + + +def test_reshape_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + ns = relax.Var("ns", relax.TensorStructInfo((120,), "float32")) + pv = relax.Var("pv", relax.PrimStructInfo("int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, [pv])) + + +def test_permute_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((1, 2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((1,), "float32")) + x7 = relax.Var("x", R.Tensor((), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x0, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), "float32"), + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), dtype="") + ) + _check_inference( + bb, relax.op.permute_dims(x3, axes=None), relax.TensorStructInfo((4, 3, 2, 1), dtype="") + ) + _check_inference( + bb, + relax.op.permute_dims(x3, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), dtype=""), + ) + _check_inference( + bb, 
relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x4, axes=None), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.permute_dims(x6, axes=None), relax.TensorStructInfo((1,), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x7, axes=None), relax.TensorStructInfo((), "float32") + ) + + +def test_permute_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorStructInfo((c, d, b, a), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x, axes=None), relax.TensorStructInfo((d, c, b, a), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x, [-2, -3, 3, -4]), + relax.TensorStructInfo((c, b, d, a), "float32"), + ) + + +def test_permute_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorStructInfo(s1, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + + +def test_permute_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float16") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int8") + ) + _check_inference( + bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int32") + ) + + +def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [2, 3, 1, 0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [2, 3, 1, 0])) + + +def 
test_permute_dims_infer_struct_info_wrong_number_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [1, 2, 4, 0, 3])) + + +def test_permute_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 3, 4, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, -5, 1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 3, 4, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, -5, 1, 3])) + + +def test_permute_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, -2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, -2, 1])) + + +def test_permute_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((1, 2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((1, 2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1)) + + +def test_expand_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + 
_check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), dtype="") + ) + _check_inference( + bb, + relax.op.expand_dims(x3, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), + ) + _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorStructInfo(dtype="")) + + +def test_expand_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, 4, b), "float32")) + + _check_inference( + bb, relax.op.expand_dims(x, [1, 3]), relax.TensorStructInfo((a, 1, 4, 1, b), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((a, 1, 1, 1, 4, 1, b, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorStructInfo((a, 4, b), "float32")) + + +def test_expand_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(s2, "float32")) + + +def test_expand_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float16") + ) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int8") + ) + _check_inference( + bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int32") + ) + + +def test_expand_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + 
bb.normalize(relax.op.expand_dims(x0, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [-6, 1])) + + +def test_expand_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, -4])) + + +def test_expand_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, axis=[])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, axis=[])) + + +def test_layout_transform_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + transpose_transform = lambda a, b, c: (a, c, b) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=transpose_transform), + relax.TensorStructInfo((10, 30, 20), "float32"), + ) + + tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_transform), + relax.TensorStructInfo((10, 10, 30, 2), "float32"), + ) + + implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2), + relax.TensorStructInfo((10, 30, 7, 3), "float32"), + ) + + flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=flatten_transform), + relax.TensorStructInfo((6000,), "float32"), + ) + + +def test_layout_transform_infer_struct_info_mismatch_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "int32")) + + transpose_transform = lambda a, b, c: (a, c, b) + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2)) + + +def test_layout_transform_infer_struct_info_unknown_shape(): + bb = relax.BlockBuilder() + tiling_transform = lambda a, b: (a, b // 2, b % 2) + + x_unknown_shape = relax.Var("x", 
R.Tensor("float32", ndim=2)) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + x_unknown_rank_dtype = relax.Var("x", R.Tensor()) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform), + relax.TensorStructInfo(dtype="", ndim=3), + ) + + +def test_layout_transform_infer_struct_info_symbolic_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + + tiling_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x0, index_map=tiling_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s = relax.Var("s", relax.ShapeStructInfo((30, 20))) + x = relax.Var("x", relax.TensorStructInfo(s, "float32")) + tiling_padding_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_padding_transform), + relax.TensorStructInfo((30, 7, 3), "float32"), + ) + + s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + s_unknown_rank = relax.Var("s", relax.ShapeStructInfo()) + x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) + x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_invalid_index_map(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a))) + + +def test_squeeze_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4))) + x4 = relax.Var("x", R.Tensor(ndim=6)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32") + ) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), dtype="") + ) + 
_check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, 1, b))) + + _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorStructInfo((a, b), "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorStructInfo((a, b), dtype="")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="float32")) + + +def test_squeeze_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int32")) + + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo((2, 3, 4), 
"int32")) + + +def test_squeeze_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [-7])) + + +def test_squeeze_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [1, 1])) + + +def test_squeeze_infer_struct_info_axis_length_not_one(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((a, 3, 4), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [0])) + _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [0])) + _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) + + +def test_squeeze_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1)) + + +def test_flatten_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((3,), "float32")) + x2 = relax.Var("x", R.Tensor((), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=3)) + 
x4 = relax.Var("x", R.Tensor("float32", ndim=1)) + x5 = relax.Var("x", R.Tensor("float32", ndim=0)) + x6 = relax.Var("x", R.Tensor("float32")) + x7 = relax.Var("x", R.Tensor((3, 4, 5))) + x8 = relax.Var("x", R.Tensor((3,))) + x9 = relax.Var("x", R.Tensor(())) + x10 = relax.Var("x", R.Tensor(ndim=3)) + x11 = relax.Var("x", R.Tensor(ndim=1)) + x12 = relax.Var("x", R.Tensor(ndim=0)) + x13 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x7), relax.TensorStructInfo((60,), dtype="")) + _check_inference(bb, relax.op.flatten(x8), relax.TensorStructInfo((3,), dtype="")) + _check_inference(bb, relax.op.flatten(x9), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x10), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x11), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x12), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x13), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_flatten_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, b))) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((a * b,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((a * b,), dtype="")) + + +def test_flatten_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo((3,))) + s2 = relax.Var("s", relax.ShapeStructInfo(())) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s6 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s6, "float32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", 
ndim=1)) + + +def test_flatten_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4, 5), "int32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float16")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((60,), "int8")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((60,), "int32")) + + +def test_flatten_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x1)) + + +def test_flatten_wrong_input_number(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TypeError): + relax.op.flatten(x, y) + + +def test_concat_infer_struct_info_with_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((2, 4, 4))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=3)) + z2 = relax.Var("z", R.Tensor("float32")) + z3 = relax.Var("z", R.Tensor((2, 5, 4))) + z4 = relax.Var("z", R.Tensor(ndim=3)) + z5 = relax.Var("z", R.Tensor()) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) 
+ ) + _check_inference( + bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo((2, 12, 4), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) + x1 = relax.Var("x", R.Tensor((a1, b0, c), "float32")) + y = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=-2), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, relax.op.concat([x1, y, z], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y, z]), axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_var(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) + sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) + sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) + sx3 = relax.Var("sx", relax.ShapeStructInfo(ndim=3)) + sx4 = relax.Var("sx", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) 
+ x3 = relax.Var("x", relax.TensorStructInfo(sx3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(sx4, "float32")) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_concat_infer_struct_info_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", R.Tensor((3,))) + x3 = relax.Var("x", R.Tensor(ndim=1)) + y0 = relax.Var("y", R.Tensor((4,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + z0 = relax.Var("z", R.Tensor((5,), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=None), relax.TensorStructInfo((12,), "float32") + ) + _check_inference( + bb, + relax.op.concat([x1, y0, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorStructInfo((12,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z1], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=None), + relax.TensorStructInfo((12,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + x0 = relax.Var("x", R.Tensor((a0,), "float32")) + x1 = relax.Var("x", R.Tensor((a0,), "")) + y0 = relax.Var("y", R.Tensor((a1,), "float32")) + y1 = relax.Var("y", R.Tensor((a1,), "")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((a0 + a1,), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo((a0 + a1,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", 
relax.ShapeStructInfo((3,))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=1)) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4,))) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + + +def test_concat_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float16")) + y0 = relax.Var("y", R.Tensor((4,), "float16")) + x1 = relax.Var("x", R.Tensor((3,), "int8")) + y1 = relax.Var("y", R.Tensor((4,), "int8")) + x2 = relax.Var("x", R.Tensor((3,), "int32")) + y2 = relax.Var("y", R.Tensor((4,), "int32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((7,), "float16") + ) + _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((7,), "int8")) + _check_inference( + bb, relax.op.concat([x2, y2], axis=None), relax.TensorStructInfo((7,), "int32") + ) + + +def test_concat_infer_struct_info_tuple_var(): + bb = relax.BlockBuilder() + a = tir.Var("a0", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + t0 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1), "float32")] + ), + ) + t1 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, b0), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t2 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t3 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + t4 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1))] + ), + ) + t5 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), dtype=""), relax.TensorStructInfo((a, b1), dtype="")] + ), + ) + t6 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="", ndim=2), relax.TensorStructInfo(dtype="")] + ), + ) + t7 = relax.Var( + "t", + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + _check_inference( + bb, relax.op.concat(t0, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.concat(t2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.concat(t4, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t5, axis=1), relax.TensorStructInfo((a, b0 + b1), dtype="") + ) + _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorStructInfo(dtype="")) + + +def 
test_concat_infer_struct_info_single_input_tensor(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) + s1 = relax.Var("s", relax.ShapeStructInfo((a,))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a), "float32")) + x1 = relax.Var("x", R.Tensor((a,), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + x3 = relax.Var("x", R.Tensor("float32", ndim=1)) + x4 = relax.Var("x", R.Tensor("float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x8 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x9 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorStructInfo((3, a), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorStructInfo((a,), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorStructInfo((a,), "float32")) + _check_inference( + bb, relax.op.concat([x2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3], axis=0), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x3], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference(bb, relax.op.concat([x4], axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorStructInfo(s0, dtype="float32")) + _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference( + bb, relax.op.concat([x6], axis=None), relax.TensorStructInfo(s1, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorStructInfo(s2, dtype="float32")) + _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorStructInfo(s3, dtype="float32")) + _check_inference( + bb, relax.op.concat([x8], axis=None), relax.TensorStructInfo(s3, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorStructInfo(s4, dtype="float32")) + + +def test_concat_infer_struct_info_zero_rank_input_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + + +def test_concat_infer_struct_info_no_input_tensor(): + bb = relax.BlockBuilder() + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=None)) + + +def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = 
relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=None)) + _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x4], axis=None)) + _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32")) + + +def test_concat_infer_struct_info_inconsistent_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,))) + y = relax.Var("y", R.Tensor((4,), "float32")) + z = relax.Var("z", R.Tensor((5,), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y, z], axis=0)) + + +def test_concat_infer_struct_info_inconsistent_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", R.Tensor((3,), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + z = relax.Var("z", R.Tensor((5,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y0, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y1, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y2, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y3, z], axis=0)) + + +def test_concat_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=1)) + + +def test_concat_infer_struct_info_unequal_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, a + 2), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + y0 = relax.Var("y", R.Tensor((3, 3), "float32")) + y1 = relax.Var("y", R.Tensor((3, a), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1, y1])) + with 
pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3, y1])) + + +def test_concat_infer_struct_info_input_not_tuple(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,), "float32")) + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(x)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(s)) + + +def test_concat_infer_struct_info_input_tuple_field_not_tensor(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([s])) + + +def test_split_infer_struct_info_by_indices(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 3, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [-2, 2, 6, 4, 8, 12, 9], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 1, 4), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, [10, 20], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorStructInfo(dtype="float32", 
ndim=2), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3], axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + + +def test_split_infer_struct_info_by_n_section(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 5, 4), "float32"), + relax.TensorStructInfo((2, 5, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 2, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, 
3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + + +def test_split_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 10, 4), "int8")) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 3, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 3, 4), "int8"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 2, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 2, 4), "int8"), + ] + ), + ) + + +def test_split_infer_struct_info_single_output(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, [], axis=1), + 
relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x0, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + + +def test_split_indices_or_sections_int64(): + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + split0 = relax.op.split(x, [3, 6], axis=1) + split1 = relax.op.split(x, 4, axis=1) + + assert split0.attrs.indices_or_sections[0].dtype == "int64" + assert split0.attrs.indices_or_sections[1].dtype == "int64" + assert split1.attrs.indices_or_sections.dtype == "int64" + + +def test_split_infer_struct_info_non_integer_indices(): + bb = relax.BlockBuilder() + a = tir.Var("c", "int64") + b = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((3, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x, [a, b], axis=1)) + + +def test_split_invalid_n_section(): + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((3, 4), "float32")) + + with pytest.raises(TVMError): + relax.op.split(x, 0, axis=1) + with pytest.raises(TVMError): + relax.op.split(x, -1, axis=1) + with pytest.raises(TVMError): + relax.op.split(x, n, axis=1) + + +def test_split_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [], axis=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [], axis=-3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=-3)) + + +def test_split_infer_invalid_struct_info_indices(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + v = relax.Var("v", relax.PrimStructInfo("int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [v], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, v, axis=1)) + + +def test_split_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, 1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=1)) + + +def test_broadcast_to_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 
1, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 3))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + _check_inference( + bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + _check_inference( + bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + + +def test_broadcast_to_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) + x1 = relax.Var("x", R.Tensor((b, 1, 1, d))) + + _check_inference( + bb, + relax.op.broadcast_to(x0, (a, b, 1, c, d)), + relax.TensorStructInfo((a, b, 1, c, d), "float32"), + ) + _check_inference( + bb, + relax.op.broadcast_to(x1, (a, b, 1, c, d)), + relax.TensorStructInfo((a, b, 1, c, d), dtype=""), + ) + + +def test_broadcast_to_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + + +def test_broadcast_to_infer_struct_info_tgt_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b, 1, c, d))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=5)) + stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) + + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorStructInfo(stgt0, 
"float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32")) + + +def test_broadcast_to_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 1, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 1, 3), "int32")) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float16") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int8") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int32") + ) + + +def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x0 = relax.Var("x", R.Tensor((2, 1), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2,))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, stgt1)) + + +def 
test_broadcast_to_infer_struct_info_not_broadcastable_static(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) + x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + stgt = relax.Var("stgt", relax.ShapeStructInfo((2, 1, 6))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (2, 1, 6))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, (2, 1, 6))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt)) + + +def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s = relax.Var("s", relax.ShapeStructInfo((2, a))) + x0 = relax.Var("x", R.Tensor((2, a), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2, b))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1))) + stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a))) + + _check_inference( + bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32") + ) + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference( + bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32") + ) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) + + +def test_broadcast_to_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 1, 3))) + x1 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + stgt = relax.Var("stgt", relax.TensorStructInfo((4, 2, 5, 3), dtype="")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (4, 2, 5, 3))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt)) + + +def test_collapse_sum_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((3, 4), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((3, 4))) + y4 = relax.Var("y", R.Tensor(ndim=2)) + y5 = relax.Var("y", R.Tensor((1, 4))) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), 
"float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y2), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x2, y4), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x4, y1), relax.TensorStructInfo(dtype="", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), dtype="") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), "float32") + ) + + +def test_collapse_sum_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) + y0 = relax.Var("y", R.Tensor((4, a), "float32")) + x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + y1 = relax.Var("x", R.Tensor((1, a + b), "float32")) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + b), "float32") + ) + + +def test_collapse_sum_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32")) + + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorStructInfo(s5, "float32")) + + +def test_collapse_sum_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + y0 = relax.Var("y", R.Tensor((3, 4), "float16")) + y1 = relax.Var("y", R.Tensor((3, 4), "int8")) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float16") + ) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) + + +def test_collapse_sum_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x0, x1)) + + 
with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x2, x0)) + + +def test_collapse_sum_like_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) + y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + + s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) + x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x0, y0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x1, y1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x2, y2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x3, y3)) + + +def test_collapse_sum_to_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorStructInfo((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorStructInfo((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorStructInfo((3, 4), "")) + + +def test_collapse_sum_to_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) + x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_more_input_dtype(): + 
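# collapse_sum_to is expected to keep the input dtype (float16, int8) unchanged. +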
bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float16") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "int8") + ) + + +def test_collapse_sum_to_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, x1)) + + +def test_collapse_sum_to_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + + s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) + + +def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) + stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) + stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorStructInfo(stgt0, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), 
relax.TensorStructInfo(stgt1, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorStructInfo(stgt2, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + + +def test_repeat_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.repeat(x0, 2, axis=0), + relax.TensorStructInfo((4, 10, 4), "float32"), + ) + _check_inference( + bb, + relax.op.repeat(x0, 2, axis=-2), + relax.TensorStructInfo((2, 20, 4), "float32"), + ) + _check_inference( + bb, + relax.op.repeat(x0, 2), + relax.TensorStructInfo((160,), "float32"), + ) + _check_inference( + bb, + relax.op.repeat(x1, 2, axis=0), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.repeat(x1, 2), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference(bb, relax.op.repeat(x2, 2, axis=0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.repeat(x2, 2), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference( + bb, + relax.op.repeat(x3, 2, axis=0), + relax.TensorStructInfo((4, 10, 4), dtype=""), + ) + _check_inference(bb, relax.op.repeat(x4, 2, axis=0), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.repeat(x5, 2, axis=0), relax.TensorStructInfo(dtype="")) + + +def test_repeat_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + + _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((a * 2, b, c), "float32")) + _check_inference( + bb, + relax.op.repeat(x, 2, -1), + relax.TensorStructInfo((a, b, c * 2), "float32"), + ) + _check_inference( + bb, + relax.op.repeat(x, 2), + relax.TensorStructInfo((a * b * c * 2,), "float32"), + ) + + +def test_repeat_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference(bb, relax.op.repeat(x0, 2, 0), relax.TensorStructInfo((4, 3, 4), "float16")) + _check_inference(bb, relax.op.repeat(x1, 2, 0), relax.TensorStructInfo((4, 3, 4), "int8")) + + +def test_repeat_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + + with 
pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x0, 2, 3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x0, 2, -4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x1, 2, 3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x1, 2, -4)) + # okay + bb.normalize(relax.op.repeat(x2, 2, 3)) + bb.normalize(relax.op.repeat(x2, 2, -4)) + + +def test_repeat_return_data_sinfo(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + + _check_inference(bb, relax.op.repeat(x0, 1, 0), x0.struct_info) + _check_inference(bb, relax.op.repeat(x0, 1, -1), x0.struct_info) + _check_inference(bb, relax.op.repeat(x1, 1, 0), x1.struct_info) + _check_inference(bb, relax.op.repeat(x2, 1, 0), x2.struct_info) + + +def test_repeat_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + r1 = tir.Var("r", "float32") + r2 = tir.StringImm("abc") + + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x0, 2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x1, 2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x2, 1.5)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x2, r1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.repeat(x2, r2)) + + +def test_tile_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.tile(x0, 2), + relax.TensorStructInfo((2, 10, 8), "float32"), + ) + _check_inference( + bb, + relax.op.tile(x0, (3, 2)), + relax.TensorStructInfo((2, 30, 8), "float32"), + ) + _check_inference( + bb, + relax.op.tile(x0, (4, 3, 2)), + relax.TensorStructInfo((8, 30, 8), "float32"), + ) + _check_inference( + bb, + relax.op.tile(x0, (5, 4, 3, 2)), + relax.TensorStructInfo((5, 8, 30, 8), "float32"), + ) + _check_inference( + bb, + relax.op.tile(x1, 2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.tile(x1, (5, 4, 3, 2)), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.tile(x2, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, + relax.op.tile(x3, 2), + relax.TensorStructInfo((2, 10, 8), dtype=""), + ) + _check_inference( + bb, + relax.op.tile(x3, (5, 4, 3, 2)), + relax.TensorStructInfo((5, 8, 30, 8), dtype=""), + ) + _check_inference(bb, relax.op.tile(x4, 2), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.tile(x4, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.tile(x5, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="")) + + +def test_tile_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + + _check_inference(bb, relax.op.tile(x, 2), relax.TensorStructInfo((a, b, c * 2), "float32")) + _check_inference( + bb, relax.op.tile(x, (3, 2)), 
relax.TensorStructInfo((a, b * 3, c * 2), "float32") + ) + _check_inference( + bb, relax.op.tile(x, (4, 3, 2)), relax.TensorStructInfo((a * 4, b * 3, c * 2), "float32") + ) + _check_inference( + bb, + relax.op.tile(x, (5, 4, 3, 2)), + relax.TensorStructInfo((5, a * 4, b * 3, c * 2), "float32"), + ) + + +def test_tile_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference(bb, relax.op.tile(x0, (3, 2)), relax.TensorStructInfo((2, 9, 8), "float16")) + _check_inference(bb, relax.op.tile(x1, (3, 2)), relax.TensorStructInfo((2, 9, 8), "int8")) + + +def test_tile_return_data_sinfo(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + + _check_inference(bb, relax.op.tile(x0, 1), x0.struct_info) + _check_inference(bb, relax.op.tile(x0, (1, 1)), x0.struct_info) + _check_inference(bb, relax.op.tile(x0, (1, 1, 1)), x0.struct_info) + _check_inference(bb, relax.op.tile(x1, 1), x1.struct_info) + _check_inference(bb, relax.op.tile(x2, 1), x2.struct_info) + + +def test_tile_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + r1 = tir.Var("a", "float32") + r2 = tir.StringImm("abc") + + with pytest.raises(TVMError): + bb.normalize(relax.op.tile(x0, 2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tile(x1, 2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tile(x2, (2, 1.5, 2))) + with pytest.raises(TVMError): + bb.normalize(relax.op.tile(x2, (2, r1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.tile(x2, r2)) + + +def test_cumsum_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32")) + _check_inference( + bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.cumsum(x2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.cumsum(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype="")) + _check_inference(bb, relax.op.cumsum(x4, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.cumsum(x5, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.cumsum(x0), relax.TensorStructInfo((80,), "float32")) + _check_inference( + bb, relax.op.cumsum(x0, axis=1, dtype="int32"), relax.TensorStructInfo((2, 10, 4), "int32") + ) + + +def test_cumsum_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + + _check_inference(bb, relax.op.cumsum(x, axis=1), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.cumsum(x), relax.TensorStructInfo((a * b * c,), "float32")) + + +def test_cumsum_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = 
relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo((2, 3, 4), "int8")) + + +def test_cumsum_wrong_input_number(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TVMError): + relax.op.cumsum(x, y) + + +def test_cumsum_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.cumsum(x0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.cumsum(x1, axis=1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py new file mode 100644 index 000000000000..d596c60196f3 --- /dev/null +++ b/tests/python/relax/test_op_misc.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax as rx +from tvm.script import relax as R +from tvm.script import tir as T + + +@tvm.register_func("test.op.identity", override=True) +def identity_packed(a): + return tvm.nd.array(a.asnumpy()) + + +@T.prim_func +def identity_tir(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [54, 96]) + B = T.match_buffer(b, [54, 96]) + + for i, j in T.grid(54, 96): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + +def test_call_tir() -> None: + v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) + v1 = rx.call_dps_packed(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) + + +def test_implicit_op(): + m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float32")) + y = rx.Var("y", R.Tensor([m, n], "float32")) + + def _check_call(expr, op_name: str): + assert isinstance(expr, rx.Call) + if not op_name.startswith("relax."): + op_name = "relax." 
+ op_name + op = tvm.ir.Op.get(op_name) + assert expr.op == op + + # Comparison operators + _check_call(x > y, "greater") + _check_call(x >= y, "greater_equal") + _check_call(x < y, "less") + _check_call(x <= y, "less_equal") + + # Arithmetic operators + _check_call(-x, "negative") + _check_call(x + y, "add") + _check_call(x - y, "subtract") + _check_call(x * y, "multiply") + _check_call(x / y, "divide") + _check_call(x // y, "floor_divide") + _check_call(x**y, "power") + # _check_call(x % y, "mod") <= relax.mod is not implemented yet + + # Cast + _check_call(x.astype("float32"), "astype") + + # Call + call_expr = x(y)(y) + assert isinstance(call_expr.op, rx.Call) + assert call_expr.op.op == x + + # GetTupleItem + ## Eager get item for tuple + tuple_expr = rx.Tuple((x, y)) + assert tuple_expr[0] == x + assert tuple_expr[1] == y + + ## Eager get item for ShapeExpr + shape_expr = rx.ShapeExpr((1, 2)) + assert shape_expr[0] == 1 + assert shape_expr[1] == 2 + + ## Create TupleGetItem for other expr + assert isinstance(x[0], rx.TupleGetItem) + assert isinstance(x[1][0], rx.TupleGetItem) + + +def test_vm_alloc_tensor(): + bb = rx.BlockBuilder() + storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32")) + alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=rx.ShapeExpr([4, 5]), dtype="float32") + alloc = bb.normalize(alloc) + tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32")) + + +def test_builtin_stop_lift_params(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32")) + x1 = rx.op.builtin.stop_lift_params(x) + x1 = bb.normalize(x1) + tvm.ir.assert_structural_equal(x1.struct_info, R.Tensor([4, 5], "float32")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py new file mode 100644 index 000000000000..51144784638a --- /dev/null +++ b/tests/python/relax/test_op_nn.py @@ -0,0 +1,1324 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
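+# Struct info inference tests for relax.nn operators: activations (relu/gelu/silu), softmax/log_softmax, dropout, batch_norm, layer_norm, group_norm, and cross_entropy_with_logits.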
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu") + assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") + assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") + assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") + assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") + + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + assert relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1).op == Op.get( + "relax.nn.batch_norm" + ) + assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.nn.cross_entropy_with_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_with_logits" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_linear_unit_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) + + +def test_linear_unit_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_linear_unit_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_linear_unit_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, 
relax.op.nn.relu(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.nn.relu(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_linear_unit_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_linear_unit_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_softmax_log_softmax_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") + ) + _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + + +def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32") + ) + + +def test_softmax_log_softmax_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo(s1, "float32")) + + +def 
test_softmax_log_softmax_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "float64")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + + +def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) + + +def test_softmax_log_softmax_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=-4)) + + +def test_softmax_log_softmax_wrong_with_multiple_axes(): + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[-1, -2, -3]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[-1, -2, -3]) + + +def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) + + +def test_batch_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor(ndim=4)) + x4 = relax.Var("x", R.Tensor()) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor(ndim=1)) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((3,))) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,))) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + moving_var2 = relax.Var("moving_var", R.Tensor(ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 
28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=-3), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x3, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(ndim=4, dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x4, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + h = tir.Var("h", "int64") + w = tir.Var("w", "int64") + x0 = relax.Var("x", R.Tensor((n, c0, h, w), "float32")) + x1 = relax.Var("x", R.Tensor((n, c1, h, w), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((c0,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((c1,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + beta = relax.Var("beta", R.Tensor((c0,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((c0,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((c0,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((c1,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + 
relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma2, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + moving_mean = relax.Var("moving_mean", relax.TensorStructInfo(s2, "float32")) + moving_var = relax.Var("moving_var", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s0, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s1, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + gamma = relax.Var("gamma", R.Tensor((3,), "float16")) + beta = relax.Var("beta", R.Tensor((3,), "float16")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float16")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float16")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, 
axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + relax.TensorStructInfo((3,), "float16"), + relax.TensorStructInfo((3,), "float16"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "int8")) + beta0 = relax.Var("beta", R.Tensor((3,), "int8")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "int8")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "int32")) + beta1 = relax.Var("beta", R.Tensor((3,), "int32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,), "int32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma1, beta1, moving_mean1, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-5)) + + +def test_batch_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3,))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "float16")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3, 1), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((1, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c 
= tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, c, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor((c + 2,), "float32")) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((c,), "float32")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((c,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((4,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor((c,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma2, beta1, moving_mean1, moving_var2, axis=1)) + + +def test_batch_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((3,), "float32"))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var, axis=1)) + + +def test_layer_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=2)) + gamma2 = relax.Var("gamma", R.Tensor((4, 5))) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, 3]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x3, gamma2, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_layer_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = 
relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((b, c0), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((b, c1), "float32")) + beta = relax.Var("beta", R.Tensor((b, c0), "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_layer_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=2)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_layer_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float16")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "float64")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "float64")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_layer_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int32")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + 
bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, 4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, -1])) + + +def test_layer_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) + + +def test_group_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor((4,))) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, 
gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x3, gamma2, beta1, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_group_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((a,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((a,), "float32")) + beta = relax.Var("beta", R.Tensor((a,), "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_group_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_group_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float16")) + beta0 = relax.Var("beta", R.Tensor((3,), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "float64")) + beta1 = relax.Var("beta", R.Tensor((3,), "float64")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + 
relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_group_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int32")) + beta1 = relax.Var("beta", R.Tensor((4,), "int32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, 4]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, -1]) + ) + + +def test_group_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = 
relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_dropout_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x3), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), dtype=""), relax.TensorStructInfo((2, 3), dtype="")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x4), + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + +def test_dropout_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((m, n), "float32")) + + _check_inference( + bb, + relax.op.nn.dropout(x), + relax.TupleStructInfo( + [relax.TensorStructInfo((m, n), "float32"), relax.TensorStructInfo((m, n), "float32")] + ), + ) + + +def test_dropout_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, "float32")] + ), + ) + + +def test_dropout_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float64"), relax.TensorStructInfo((2, 3), "float64")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int8"), relax.TensorStructInfo((2, 3), "int8")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int64"), relax.TensorStructInfo((2, 3), "int64")] + ), + ) + + +def 
test_dropout_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x1)) + + +def test_cross_entropy_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor((2, 3))) + y3 = relax.Var("y", R.Tensor(ndim=2)) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, + relax.op.nn.cross_entropy_with_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + + +def test_cross_entropy_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m0 = tir.Var("m", "int64") + m1 = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m0, n), "float32")) + x1 = relax.Var("x", R.Tensor((m1, n), "float32")) + y = relax.Var("y", R.Tensor((m0, n), "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + y2 = relax.Var("y", R.Tensor((2, 3), "int32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + + +def test_cross_entropy_infer_struct_info_wrong_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y1)) + + +def test_cross_entropy_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + x0 = relax.Var("x", 
R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y0)) + + +def test_cross_entropy_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py new file mode 100644 index 000000000000..d1d604429e93 --- /dev/null +++ b/tests/python/relax/test_op_nn_convolution.py @@ -0,0 +1,1190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_conv1d_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + assert relax.op.nn.conv1d(x, w).op == Op.get("relax.nn.conv1d") + + +def test_conv2d_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d") + assert relax.op.nn.conv2d_transpose(x, w).op == Op.get("relax.nn.conv2d_transpose") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_conv1d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=3)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32")) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 26), "float16"), + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 
28), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, padding=[1, 3]), + relax.TensorStructInfo((2, 4, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=(2,)), + relax.TensorStructInfo((2, 4, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, dilation=(2,)), + relax.TensorStructInfo((2, 4, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w0, data_layout="NWC"), + relax.TensorStructInfo((2, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_layout="NWC"), + relax.TensorStructInfo((2, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1, kernel_layout="IOW"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d( + x5, w4, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NWC16c" + ), + relax.TensorStructInfo((2, 26, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv1d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.conv1d(x4, w0), relax.TensorStructInfo(dtype="", ndim=3)) + + +def test_conv1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ko, ki, kw), "float32")) + w1 = relax.Var("w", R.Tensor((ko, c, kw), "float32")) + w2 = relax.Var("w", R.Tensor((ko, c, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w2, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NCW"), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=2, padding=1, dilation=2), + relax.TensorStructInfo( + (n, ko, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + "float32", + ), + ) + + +def test_conv1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3)) + 
_check_inference( + bb, + relax.op.nn.conv1d(x1, w, data_layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w, out_layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_conv1d_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((48, 16, 3), "float32")) + w1 = relax.Var("w", R.Tensor((48, 2, 3, 8), "float32")) + + _check_inference( + bb, relax.op.nn.conv1d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1, kernel_layout="OIW8i", groups=8), + relax.TensorStructInfo((2, 48, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w0, data_layout="NCW16c", groups=8), + relax.TensorStructInfo((2, 3, 26, 16), "float32"), + ) + + +def test_conv1d_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32")) + w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv1d(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 26), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv1d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26), "float32") + ) + + +def test_conv1d_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + w0 = relax.Var("w", R.Tensor((48, 20, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6)) + + +def test_conv1d_infer_struct_info_output_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 120, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6)) + + +def test_conv1d_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + w = relax.Var("w", R.Tensor((48, 16, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, groups=-2) + + +def test_conv1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28), "float64")) + w1 = relax.Var("w", R.Tensor((4, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28), "int8")) + w2 = relax.Var("w", 
R.Tensor((4, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28), "int32")) + w3 = relax.Var("w", R.Tensor((4, 3, 3), "int32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float16")) + _check_inference(bb, relax.op.nn.conv1d(x1, w1), relax.TensorStructInfo((2, 4, 26), "float64")) + _check_inference(bb, relax.op.nn.conv1d(x2, w2), relax.TensorStructInfo((2, 4, 26), "int8")) + _check_inference(bb, relax.op.nn.conv1d(x3, w3), relax.TensorStructInfo((2, 4, 26), "int32")) + + +def test_conv1d_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8")) + w1 = relax.Var("w", R.Tensor((4, 3, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 28))) + w2 = relax.Var("w", R.Tensor((4, 3, 3))) + + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 26), "int32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x2, w2, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + + +def test_conv1d_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32")) + w0 = relax.Var("w", R.Tensor([3, 4, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32")) + w1 = relax.Var("w", R.Tensor([4, ic + 2, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1)) + + +def test_conv1d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1), dilation=(1,)) + + assert conv1d.attrs.strides[0].dtype == "int64" + assert conv1d.attrs.padding[0].dtype == "int64" + assert conv1d.attrs.padding[1].dtype == "int64" + assert conv1d.attrs.dilation[0].dtype == "int64" + + +def test_conv1d_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, strides=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, dilation=(1, 2)) + + +def test_conv1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, data_layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, kernel_layout="NWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, out_layout="OWI")) + + +def test_conv1d_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w)) + + +def test_conv1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 3), "float32")) + x2 = relax.Var("x", 
R.Tensor("float32", ndim=2)) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 6, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=5)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1, data_layout="NCW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x2, w0)) + + +def test_conv1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w0)) + + +def test_conv2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=4)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 26, 26), "float16"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2]), + relax.TensorStructInfo((2, 4, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2, 3, 4]), + relax.TensorStructInfo((2, 4, 30, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 13, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 3)), + relax.TensorStructInfo((2, 4, 13, 9), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 24, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=(2, 1)), + relax.TensorStructInfo((2, 4, 24, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="IOHW"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x5, w4, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NHWC16c" + ), + relax.TensorStructInfo((2, 26, 26, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + 
) + _check_inference( + bb, relax.op.nn.conv2d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_conv2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kh = tir.Var("kh", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ko, ki, kh, kw), "float32")) + w1 = relax.Var("w", R.Tensor((ko, c, kh, kw), "float32")) + w2 = relax.Var("w", R.Tensor((ko, c, kh, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x1, w2, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NCHW" + ), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), dilation=(2, 2)), + relax.TensorStructInfo( + (n, ko, tvm.tir.floordiv(ih + 3, 2) + 1 - kh, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + "float32", + ), + ) + + +def test_conv2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w, data_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_conv2d_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8), + relax.TensorStructInfo((2, 48, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8), + relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"), + ) + + +def 
test_conv2d_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32") + ) + + +def test_conv2d_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_infer_struct_info_output_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=-2) + + +def test_conv2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + w3 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float16") + ) + _check_inference( + bb, relax.op.nn.conv2d(x1, w1), relax.TensorStructInfo((2, 4, 26, 26), "float64") + ) + _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorStructInfo((2, 4, 26, 26), "int8")) + _check_inference( + bb, relax.op.nn.conv2d(x3, w3), relax.TensorStructInfo((2, 4, 26, 26), "int32") + ) + + +def test_conv2d_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28))) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3))) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, 
out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 26, 26), "int32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w2, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + + +def test_conv2d_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) + w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) + w1 = relax.Var("w", R.Tensor([4, ic + 2, 3, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1)) + + +def test_conv2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert conv2d.attrs.strides[0].dtype == "int64" + assert conv2d.attrs.strides[1].dtype == "int64" + assert conv2d.attrs.padding[0].dtype == "int64" + assert conv2d.attrs.padding[1].dtype == "int64" + assert conv2d.attrs.padding[2].dtype == "int64" + assert conv2d.attrs.padding[3].dtype == "int64" + assert conv2d.attrs.dilation[0].dtype == "int64" + assert conv2d.attrs.dilation[1].dtype == "int64" + + +def test_conv2d_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, dilation=(1, 2, 3)) + + +def test_conv2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, data_layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, kernel_layout="NHWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, out_layout="OHWI")) + + +def test_conv2d_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w)) + + +def test_conv2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 6, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=6)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1, data_layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x2, w0)) + + +def test_conv2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = 
relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + + +def test_conv2d_transpose_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=4)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((4, 48, 3, 3, 16), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 30, 30), "float16"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, padding=1), + relax.TensorStructInfo((2, 4, 28, 28), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2]), + relax.TensorStructInfo((2, 4, 28, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2, 3, 4]), + relax.TensorStructInfo((2, 4, 26, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=1), + relax.TensorStructInfo((2, 4, 85, 85), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=[2, 1]), + relax.TensorStructInfo((2, 4, 86, 85), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 57, 57), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 3)), + relax.TensorStructInfo((2, 4, 57, 84), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, dilation=(2, 1)), + relax.TensorStructInfo((2, 4, 32, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x1, w0, data_layout="NHWC"), + relax.TensorStructInfo((2, 30, 30, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, out_layout="NHWC"), + relax.TensorStructInfo((2, 30, 30, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="OIHW"), + relax.TensorStructInfo((2, 4, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose( + x5, w4, data_layout="NCHW16c", kernel_layout="IOHW16i", out_layout="NHWC16c" + ), + relax.TensorStructInfo((2, 30, 30, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x0, w2), 
relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x4, w0), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_conv2d_transpose_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kh = tir.Var("kh", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ki, ko, kh, kw), "float32")) + w1 = relax.Var("w", R.Tensor((c, ko, kh, kw), "float32")) + w2 = relax.Var("w", R.Tensor((c, ko, kh, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0), + relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w1), + relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose( + x1, w2, data_layout="NCHW16c", kernel_layout="IOHW16i", out_layout="NCHW" + ), + relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose( + x0, w0, strides=(2, 2), padding=(1, 1), output_padding=(1, 0), dilation=(2, 2) + ), + relax.TensorStructInfo( + (n, ko, ih * 2 + kh * 2 - 4, iw * 2 + kw * 2 - 5), + "float32", + ), + ) + + +def test_conv2d_transpose_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.conv2d_transpose(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x1, w, data_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_conv2d_transpose_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((128, 6, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((16, 6, 3, 3, 8), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, groups=8), + relax.TensorStructInfo((2, 48, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="IOHW8i", groups=8), + relax.TensorStructInfo((2, 48, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x1, w0, data_layout="NCHW16c", groups=8), + relax.TensorStructInfo((2, 3, 30, 30, 16), "float32"), + ) + + +def 
test_conv2d_transpose_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((ic, oc, 3, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 30, 30), "float32"), + ) + + +def test_conv2d_transpose_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((ic - 1, oc, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x1, w1, groups=6)) + + +def test_conv2d_transpose_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((128, 16, 3, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, groups=-2) + + +def test_conv2d_transpose_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64")) + w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w2 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + w3 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int32")) + + _check_inference( + bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float16") + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x1, w1), relax.TensorStructInfo((2, 4, 30, 30), "float64") + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x2, w2), relax.TensorStructInfo((2, 4, 30, 30), "int8") + ) + _check_inference( + bb, relax.op.nn.conv2d_transpose(x3, w3), relax.TensorStructInfo((2, 4, 30, 30), "int32") + ) + + +def test_conv2d_transpose_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) + w0 = relax.Var("w", R.Tensor([4, 3, 3, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) + w1 = relax.Var("w", R.Tensor([ic + 2, 4, 3, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x1, w1)) + + +def test_conv2d_transpose_wrong_output_padding(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) + w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=2, output_padding=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 2), output_padding=(2, 2))) + + +def test_conv2d_transpose_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((3, 4, 3, 3), 
"float32")) + conv2d_transpose = relax.op.nn.conv2d_transpose( + x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2), dilation=(1, 1) + ) + + assert conv2d_transpose.attrs.strides[0].dtype == "int64" + assert conv2d_transpose.attrs.strides[1].dtype == "int64" + assert conv2d_transpose.attrs.padding[0].dtype == "int64" + assert conv2d_transpose.attrs.padding[1].dtype == "int64" + assert conv2d_transpose.attrs.padding[2].dtype == "int64" + assert conv2d_transpose.attrs.padding[3].dtype == "int64" + assert conv2d_transpose.attrs.output_padding[0].dtype == "int64" + assert conv2d_transpose.attrs.output_padding[1].dtype == "int64" + assert conv2d_transpose.attrs.dilation[0].dtype == "int64" + assert conv2d_transpose.attrs.dilation[1].dtype == "int64" + + +def test_conv2d_transpose_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, output_padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d_transpose(x, w, dilation=(1, 2, 3)) + + +def test_conv2d_transpose_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x, w, data_layout="IOHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x, w, kernel_layout="NHWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x, w, out_layout="OHWI")) + + +def test_conv2d_transpose_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x, w)) + + +def test_conv2d_transpose_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 6, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=6)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w1, data_layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x2, w0)) + + +def test_conv2d_transpose_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d_transpose(x1, w0)) + + +if __name__ == "__main__": + 
tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py new file mode 100644 index 000000000000..2bd7747f3132 --- /dev/null +++ b/tests/python/relax/test_op_nn_pooling.py @@ -0,0 +1,655 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") + assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d") + assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_max_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=(5, 3)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 34, 36), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + 
bb, relax.op.nn.max_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.max_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_max_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x2), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_max_pool2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 
1), dilation=(1, 1)) + + assert max_pool2d.attrs.strides[0].dtype == "int64" + assert max_pool2d.attrs.strides[1].dtype == "int64" + assert max_pool2d.attrs.padding[0].dtype == "int64" + assert max_pool2d.attrs.padding[1].dtype == "int64" + assert max_pool2d.attrs.padding[2].dtype == "int64" + assert max_pool2d.attrs.padding[3].dtype == "int64" + assert max_pool2d.attrs.dilation[0].dtype == "int64" + assert max_pool2d.attrs.dilation[1].dtype == "int64" + + +def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, pool_size=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, dilation=(1, 2, 3)) + + +def test_max_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, out_layout="OHWI")) + + +def test_max_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_max_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 34, 36), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + 
relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d( + x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x2), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_avg_pool2d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d( + x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) + _check_inference( + bb, 
relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_avg_pool2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert avg_pool2d.attrs.strides[0].dtype == "int64" + assert avg_pool2d.attrs.strides[1].dtype == "int64" + assert avg_pool2d.attrs.padding[0].dtype == "int64" + assert avg_pool2d.attrs.padding[1].dtype == "int64" + assert avg_pool2d.attrs.padding[2].dtype == "int64" + assert avg_pool2d.attrs.padding[3].dtype == "int64" + assert avg_pool2d.attrs.dilation[0].dtype == "int64" + assert avg_pool2d.attrs.dilation[1].dtype == "int64" + + +def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, pool_size=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3)) + + +def test_avg_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x, out_layout="OHWI")) + + +def test_avg_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x1)) + + +def test_avg_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, 
relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)), + relax.TensorStructInfo((n, c, 256, 128), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=32), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x2, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_adaptive_avg_pool2d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool2d(x, (32, 32, 32)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + 
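# "OHWI" is not a valid output layout for 2-D pooling data, so normalization is expected to raise. + 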
bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, out_layout="OHWI")) + + +def test_adaptive_avg_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py new file mode 100644 index 000000000000..ba78d11022b4 --- /dev/null +++ b/tests/python/relax/test_op_search.py @@ -0,0 +1,436 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
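+# Struct-info inference tests for the Relax search operators: relax.op.where, relax.op.argmax and relax.op.argmin.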
+from typing import Callable + +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.where(cond, x, y).op == Op.get("relax.where") + assert relax.op.argmax(x).op == Op.get("relax.argmax") + assert relax.op.argmin(x).op == Op.get("relax.argmin") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_where_infer_struct_info(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor("bool", ndim=5)) + cond2 = relax.Var("cond", R.Tensor("bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((5, 1, 3, 2))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((4, 3, 1))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32") + ) + _check_inference( + bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x3, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x3, y3), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorStructInfo(dtype="")) + 
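# cond1 is known only by rank and cond2 has a fully unknown shape, so the checks below keep the float32 dtype while the inferred shape degrades to rank-only or unknown. + 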
_check_inference( + bb, relax.op.where(cond1, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorStructInfo(dtype="float32")) + + +def test_where_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d0 = tir.Var("d", "int64") + d1 = tir.Var("d", "int64") + e = tir.Var("e", "int64") + cond = relax.Var("cond", R.Tensor((a, b, 1, d0, 1), "bool")) + x0 = relax.Var("x", R.Tensor((b, 1, d0, e), "float32")) + x1 = relax.Var("x", R.Tensor((b, 1, d1, e), "float32")) + x2 = relax.Var("x", R.Tensor((b, 1, d0, e))) + y0 = relax.Var("y", R.Tensor((c, d0, 1), "float32")) + y1 = relax.Var("y", R.Tensor((c, d0, 1))) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((a, b, c, d0, e), "float32") + ) + _check_inference( + bb, relax.op.where(cond, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond, x2, y0), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference( + bb, relax.op.where(cond, x0, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference( + bb, relax.op.where(cond, x2, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + + +def test_where_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + scond0 = relax.Var("scond", relax.ShapeStructInfo((6, 5, 1, 3, 1))) + scond1 = relax.Var("scond", relax.ShapeStructInfo(ndim=5)) + scond2 = relax.Var("scond", relax.ShapeStructInfo()) + sx0 = relax.Var("sx", relax.ShapeStructInfo((5, 1, 3, 2))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=4)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4, 3, 1))) + sy1 = relax.Var("sy", relax.ShapeStructInfo(ndim=3)) + sy2 = relax.Var("sy", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeStructInfo((6, 5, 4, 3, 2))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + cond0 = relax.Var("cond", relax.TensorStructInfo(scond0, "bool")) + cond1 = relax.Var("cond", relax.TensorStructInfo(scond1, "bool")) + cond2 = relax.Var("cond", relax.TensorStructInfo(scond2, "bool")) + cond3 = relax.Var("cond", relax.TensorStructInfo(s0, "bool")) + cond4 = relax.Var("cond", relax.TensorStructInfo(s1, "bool")) + cond5 = relax.Var("cond", relax.TensorStructInfo(s2, "bool")) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(sy1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(sy2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y5 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, 
relax.op.where(cond0, x0, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond1, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.where(cond3, x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond3, x4, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond4, x3, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorStructInfo(s2, "float32")) + + +def test_where_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float16")) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float16")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int8")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "int8")) + x2 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int32")) + y2 = relax.Var("y", R.Tensor((4, 3, 1), "int32")) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float16") + ) + _check_inference( + bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo((6, 5, 4, 3, 2), "int8") + ) + _check_inference( + bb, relax.op.where(cond, x2, y2), relax.TensorStructInfo((6, 5, 4, 3, 2), "int32") + ) + + +def test_where_infer_struct_info_cond_not_boolean(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((2, 3), "float32")) + cond1 = relax.Var("cond", R.Tensor((2, 3))) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x, y)) + + +def test_where_infer_struct_info_shape_unequal_const_int(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 4, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 4, 2), "float32")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + y0 = relax.Var("y", R.Tensor((4, 4, 1), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with 
pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +def test_where_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x1, y1)) + + +def test_where_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", relax.ShapeStructInfo((2, 3))) + cond1 = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x1 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.TupleStructInfo([R.Tensor((2, 3), "float32")])) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +(argmax_argmin_op,) = tvm.testing.parameters((relax.op.argmax,), (relax.op.argmin,)) + + +def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + + _check_inference(bb, argmax_argmin_op(x0, axis=1), relax.TensorStructInfo((2, 4, 5), "int64")) + _check_inference( + bb, + argmax_argmin_op(x0, axis=1, keepdims=True), + relax.TensorStructInfo((2, 1, 4, 5), "int64"), + ) + _check_inference(bb, argmax_argmin_op(x0, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference( + bb, + argmax_argmin_op(x0, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "int64"), + ) + _check_inference( + bb, argmax_argmin_op(x1, axis=1), relax.TensorStructInfo(dtype="int64", ndim=3) + ) + _check_inference( + bb, + argmax_argmin_op(x1, axis=1, keepdims=True), + relax.TensorStructInfo(dtype="int64", ndim=4), + ) + _check_inference(bb, argmax_argmin_op(x1, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference( + bb, + argmax_argmin_op(x1, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "int64"), + ) + _check_inference(bb, argmax_argmin_op(x2, axis=1), relax.TensorStructInfo(dtype="int64")) + _check_inference( + bb, + argmax_argmin_op(x2, axis=1, keepdims=True), + relax.TensorStructInfo(dtype="int64"), + ) + _check_inference(bb, argmax_argmin_op(x2, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference( + bb, + argmax_argmin_op(x2, axis=None, keepdims=True), + relax.TensorStructInfo(dtype="int64"), + ) + _check_inference( + bb, argmax_argmin_op(x3, axis=1), relax.TensorStructInfo((2, 4, 5), dtype="int64") + ) + _check_inference( + bb, + argmax_argmin_op(x3, axis=1, keepdims=True), + relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"), + ) + _check_inference(bb, argmax_argmin_op(x3, axis=None), relax.TensorStructInfo((), dtype="int64")) + _check_inference( + bb, + argmax_argmin_op(x3, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), dtype="int64"), + ) + _check_inference( + bb, + 
argmax_argmin_op(x0, axis=1, keepdims=True), + relax.TensorStructInfo((2, 1, 4, 5), "int64"), + ) + _check_inference(bb, argmax_argmin_op(x0, axis=-1), relax.TensorStructInfo((2, 3, 4), "int64")) + + +def test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "int64")) + + _check_inference(bb, argmax_argmin_op(x, axis=1), relax.TensorStructInfo((a, c, d), "int64")) + _check_inference( + bb, + argmax_argmin_op(x, axis=1, keepdims=True), + relax.TensorStructInfo((a, 1, c, d), "int64"), + ) + _check_inference(bb, argmax_argmin_op(x, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference( + bb, + argmax_argmin_op(x, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "int64"), + ) + + +def test_argmax_argmin_infer_struct_info_shape_var(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "int64")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "int64")) + + _check_inference(bb, argmax_argmin_op(x0), relax.TensorStructInfo((), dtype="int64")) + _check_inference( + bb, argmax_argmin_op(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="int64") + ) + _check_inference( + bb, argmax_argmin_op(x0, axis=2), relax.TensorStructInfo(dtype="int64", ndim=3) + ) + _check_inference( + bb, + argmax_argmin_op(x0, axis=2, keepdims=True), + relax.TensorStructInfo(dtype="int64", ndim=4), + ) + _check_inference(bb, argmax_argmin_op(x1), relax.TensorStructInfo((), dtype="int64")) + _check_inference(bb, argmax_argmin_op(x1, keepdims=True), relax.TensorStructInfo(dtype="int64")) + _check_inference(bb, argmax_argmin_op(x1, axis=2), relax.TensorStructInfo(dtype="int64")) + _check_inference( + bb, argmax_argmin_op(x1, axis=2, keepdims=True), relax.TensorStructInfo(dtype="int64") + ) + + +def test_argmax_argmin_infer_struct_info_more_input_dtype(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, argmax_argmin_op(x0), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x1), relax.TensorStructInfo((), "int64")) + + +def test_argmax_argmin_infer_struct_info_axis_out_of_range(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int64")) + x1 = relax.Var("x", R.Tensor("int64", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x0, axis=4)) + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x0, axis=-5)) + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x1, axis=4)) + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x1, axis=-5)) + + +def test_argmax_argmin_infer_struct_info_wrong_input_type(argmax_argmin_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "int64"))) + + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x0)) + with pytest.raises(TVMError): + bb.normalize(argmax_argmin_op(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_set.py 
b/tests/python/relax/test_op_set.py new file mode 100644 index 000000000000..755d5e8f870c --- /dev/null +++ b/tests/python/relax/test_op_set.py @@ -0,0 +1,862 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.unique(x).op == Op.get("relax.unique") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_unique_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, 
return_counts=True, axis=1 + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + 
relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), 
"float32")) + + _check_inference( + bb, + relax.op.unique( + x, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + 
relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = 
relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float16", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int8", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_input_zero_rank(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [relax.TensorStructInfo((1,), "float32"), relax.TensorStructInfo((1,), "int64")] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo((1,), "float32"), + ) + + +def test_unique_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1, axis=0)) + + +def test_unique_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py 
new file mode 100644 index 000000000000..b1bdd8e44d85 --- /dev/null +++ b/tests/python/relax/test_op_statistical.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.max(x).op == Op.get("relax.max") + assert relax.op.mean(x).op == Op.get("relax.mean") + assert relax.op.min(x).op == Op.get("relax.min") + assert relax.op.prod(x).op == Op.get("relax.prod") + assert relax.op.std(x).op == Op.get("relax.std") + assert relax.op.sum(x).op == Op.get("relax.sum") + assert relax.op.variance(x).op == Op.get("relax.variance") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_statistical_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + + _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.mean(x1, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.mean(x1, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, + relax.op.variance(x2, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.variance(x2, axis=None, keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorStructInfo((2, 5), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=[1, 2], keepdims=True), + 
relax.TensorStructInfo((2, 1, 1, 5), dtype=""), + ) + _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorStructInfo((), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), dtype=""), + ) + _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.prod(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.std(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorStructInfo((2, 3, 4, 5), "float32")) + + +def test_statistical_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((a, d), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((a, 1, 1, d), "float32"), + ) + _check_inference(bb, relax.op.min(x, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + + +def test_statistical_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.max(x0), relax.TensorStructInfo((), dtype="float32")) + _check_inference( + bb, relax.op.max(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="float32") + ) + _check_inference( + bb, relax.op.max(x0, axis=[2, 3]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.max(x0, axis=[2, 3], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.max(x1), relax.TensorStructInfo((), dtype="float32")) + _check_inference(bb, relax.op.max(x1, keepdims=True), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.max(x1, axis=[2, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.max(x1, axis=[2, 3], keepdims=True), relax.TensorStructInfo(dtype="float32") + ) + + +def test_statistical_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.sum(x0), relax.TensorStructInfo((), "float16")) + _check_inference(bb, relax.op.sum(x1), relax.TensorStructInfo((), "int8")) + + +def test_statistical_infer_struct_info_axis_out_of_range_repetitive(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[3, 3])) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.mean(x0, axis=[-1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[-4, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[-5])) + + +def test_statistical_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py new file mode 100644 index 000000000000..5ea7a01da701 --- /dev/null +++ b/tests/python/relax/test_op_ternary.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + assert relax.op.ewise_fma(x, y, z).op == Op.get("relax.ewise_fma") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_ewise_fma_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3))) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor("float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.ewise_fma(x1, y0, z0), relax.TensorStructInfo((2, 3), dtype="")) + + +def test_ewise_fma_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((m, n), "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), 
relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + z = relax.Var("z", relax.TensorStructInfo(s0, "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y, z), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.ewise_fma(x1, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x2, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + z0 = relax.Var("z", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + z2 = relax.Var("z", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.ewise_fma(x1, y1, z1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_ewise_fma_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "int32")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z1)) + + +def test_ewise_fma_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor(dtype="float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z1)) + + +def test_ewise_fma_wrong_input_number(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + relax.op.ewise_fma(x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x, x, x) + + +def test_ewise_fma_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.ShapeStructInfo((2, 3))) + y1 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git 
a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py new file mode 100644 index 000000000000..45336661a1ae --- /dev/null +++ b/tests/python/relax/test_op_unary.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.abs(x).op == Op.get("relax.abs") + assert relax.op.acos(x).op == Op.get("relax.acos") + assert relax.op.acosh(x).op == Op.get("relax.acosh") + assert relax.op.asin(x).op == Op.get("relax.asin") + assert relax.op.asinh(x).op == Op.get("relax.asinh") + assert relax.op.atan(x).op == Op.get("relax.atan") + assert relax.op.atanh(x).op == Op.get("relax.atanh") + assert relax.op.ceil(x).op == Op.get("relax.ceil") + assert relax.op.cos(x).op == Op.get("relax.cos") + assert relax.op.cosh(x).op == Op.get("relax.cosh") + assert relax.op.exp(x).op == Op.get("relax.exp") + assert relax.op.floor(x).op == Op.get("relax.floor") + assert relax.op.isfinite(x).op == Op.get("relax.isfinite") + assert relax.op.isinf(x).op == Op.get("relax.isinf") + assert relax.op.isnan(x).op == Op.get("relax.isnan") + assert relax.op.log(x).op == Op.get("relax.log") + assert relax.op.negative(x).op == Op.get("relax.negative") + assert relax.op.round(x).op == Op.get("relax.round") + assert relax.op.sigmoid(x).op == Op.get("relax.sigmoid") + assert relax.op.sin(x).op == Op.get("relax.sin") + assert relax.op.sinh(x).op == Op.get("relax.sinh") + assert relax.op.square(x).op == Op.get("relax.square") + assert relax.op.sqrt(x).op == Op.get("relax.sqrt") + assert relax.op.tan(x).op == Op.get("relax.tan") + assert relax.op.tanh(x).op == Op.get("relax.tanh") + assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +unary_arith_op, require_float_dtype = tvm.testing.parameters( + (relax.op.abs, False), + (relax.op.acos, True), + (relax.op.acosh, True), + (relax.op.asin, True), + (relax.op.asinh, True), + (relax.op.atan, True), + (relax.op.atanh, True), + (relax.op.ceil, False), + (relax.op.cos, True), + (relax.op.cosh, True), + (relax.op.exp, True), + (relax.op.floor, False), + (relax.op.log, True), + (relax.op.negative, False), + (relax.op.round, False), + (relax.op.sigmoid, True), + (relax.op.sign, False), + (relax.op.sin, True), + (relax.op.sinh, True), + (relax.op.square, False), + (relax.op.sqrt, True), + (relax.op.tan, True), + (relax.op.tanh, True), +) + + +def 
test_unary_arith_infer_struct_info(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, unary_arith_op(x4), relax.TensorStructInfo(dtype="")) + + +def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_unary_arith_infer_struct_info_more_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_unary_arith_infer_struct_info_invalid_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if not require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_unary_arith_wrong_input_number(unary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + unary_arith_op(x, x) + with pytest.raises(TypeError): + unary_arith_op(x, x, x) + + +def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_clip_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = 
relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype="")) + + # Symbolic + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x5 = relax.Var("x", R.Tensor((m, n), "float32")) + x6 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorStructInfo((4, n), "float32")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py new file mode 100644 index 000000000000..c66066f8f830 --- /dev/null +++ b/tests/python/relax/test_pipeline.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relax +from tvm.script import relax as R + + +def test_pipeline_compile(): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + lv0 = R.add(x, y) + return lv0 + + mod = Mod + mod = pipeline(mod) + target = tvm.target.Target("llvm", host="llvm") + + ex = relax.build(mod, target) + x_np = np.random.rand(3, 4).astype(np.float32) + y_np = np.random.rand(3, 4).astype(np.float32) + x = tvm.nd.array(x_np) + y = tvm.nd.array(y_np) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + z = vm["main"](x, y) + tvm.testing.assert_allclose(z.numpy(), x_np + y_np, rtol=1e-7, atol=1e-7) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py new file mode 100644 index 000000000000..776abbce764d --- /dev/null +++ b/tests/python/relax/test_relax_operators.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import tempfile + +import numpy as np +import tvm +import tvm.testing +from tvm import relax +from tvm._ffi.base import TVMError +from tvm.script import relax as R, tir as T + + +@tvm.script.ir_module +class InputModule: + @R.function + def foo(x: R.Tensor(("m", "n"), "int64")): + y = R.unique(x, sorted=False) + y_sorted = R.unique(x) + return y, y_sorted + + +def run_cpu(mod, func_name, *input): + target = tvm.target.Target("llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + return vm[func_name](*input) + + +def test_unique(): + + # TODO(prakalp): also add test for compiling and running on cuda device. + data_numpy = np.random.randint(0, 16, (16, 16)) + data = tvm.nd.array(data_numpy) + result, result_sorted = run_cpu(InputModule, "foo", data) + + expected_output_sorted, indices = np.unique(data_numpy, return_index=True) + expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + + np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy()) + np.testing.assert_array_equal(expected_output, result.numpy()) + + +@tvm.script.ir_module +class PrintTest: + @R.function + def foo(x: R.Tensor((), "int32")): + # results have to be bound, but we don't use them + # TODO: We should allow calls whose results are not bound for side effects; + # it would be easy syntactic sugar to add. 
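+ # the calls below cover a bare value, a tuple, and custom format strings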
+ p1 = R.print(x) + p2 = R.print(x, format="Number: {}") + t = (x, x) + p3 = R.print(t, format="Tuple: {}") + p4 = R.print(x, t) + p5 = R.print(x, x, format="Custom print: {} {}") + p6 = R.print(x, t, format="Another print: {} {}") + return x + + +def test_print(): + try: + stdout = sys.stdout + with tempfile.TemporaryFile(mode="w+") as test_out: + sys.stdout = test_out + run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32"))) + test_out.seek(0) + printed_text = str(test_out.read()) + expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 1\nAnother print: 1 (1, 1)\n" + assert printed_text in expected, ("printed_text is ", printed_text) + finally: + sys.stdout = stdout + + +@tvm.script.ir_module +class AssertOpTest: + @R.function + def passes(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(True)) + return x + + @R.function + def pass_with_args(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(True), x, format="You won't see me") + return x + + @R.function + def simple_fail(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False)) + return x + + @R.function + def fail_with_message(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False), format="I failed...") + return x + + @R.function + def fail_with_args(x: R.Tensor((), "int32")): + # no format + p1 = R.assert_op(relax.const(False), [x, x]) + return x + + @R.function + def fail_with_formatted_message(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False), x, format="Number: {}") + return x + + +def test_assert_op(): + def check_assertion_error(func_name, func_arg, expected_message): + passed = False + try: + run_cpu(AssertOpTest, func_name, func_arg) + passed = True + except TVMError as e: + # TVM will print out a TVMError that will contain the + # generated error at the bottom of a stack trace + assert "AssertionError" in e.args[0] + assert expected_message in e.args[0] + assert not passed + + run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(np.array(2).astype("int32"))) + check_assertion_error( + "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion Failed" + ) + check_assertion_error( + "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I failed..." 
+ ) + check_assertion_error("fail_with_args", tvm.nd.array(np.array(5).astype("int32")), "5, 5") + check_assertion_error( + "fail_with_formatted_message", tvm.nd.array(np.array(6).astype("int32")), "Number: 6" + ) + + +@tvm.script.ir_module +class ShapeOfTest: + @R.function + def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1): + return R.shape_of(t) + + @R.function + def get_constrained_shape(t: R.Tensor(ndim=1, dtype="int32")) -> R.Shape(ndim=1): + # require the input tensor to have rank 1 + return R.shape_of(t) + + @R.function + def get_scalar_shape() -> R.Shape(()): + x: R.Tensor((), "int32") = R.const(1, dtype="int32") + return R.shape_of(x) + + @R.function + def get_constant_shape() -> R.Shape((2, 2)): + x: R.Tensor((2, 2), "int32") = R.const( + np.array([[1, 2], [3, 4]], dtype="int32"), dtype="int32" + ) + return R.shape_of(x) + + +def test_op_shape_of(): + unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape") + assert unit_shape == tvm.runtime.ShapeTuple([]) + + const_shape = run_cpu(ShapeOfTest, "get_constant_shape") + assert const_shape == tvm.runtime.ShapeTuple([2, 2]) + + scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32"))) + assert scalar_shape == tvm.runtime.ShapeTuple([]) + + tensor_shape = run_cpu( + ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")) + ) + assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3]) + + constrained_shape = run_cpu( + ShapeOfTest, "get_constrained_shape", tvm.nd.array(np.zeros((1,)).astype("int32")) + ) + assert constrained_shape == tvm.runtime.ShapeTuple([1]) + + +@tvm.script.ir_module +class ShapeToTensorTest: + @R.function + def const_shape(shape: R.Shape(ndim=-1)) -> R.Tensor(ndim=-1): + return R.shape_to_tensor(shape) + + @R.function + def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1): + m = T.int64() + n = T.int64() + return R.shape_to_tensor(shape) + + +def test_op_shape_to_tensor(): + # Check struct info + assert isinstance(ShapeToTensorTest["const_shape"].body.struct_info, tvm.relax.TensorStructInfo) + assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1 + assert isinstance(ShapeToTensorTest["symbolic_shape"].body.struct_info, tvm.relax.TensorStructInfo) + assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1 + + # Check its functionality + out2d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2])) + assert isinstance(out2d, tvm.runtime.ndarray.NDArray) + assert np.array_equal(out2d.numpy(), np.array([3, 2])) + + out3d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2])) + assert isinstance(out3d, tvm.runtime.ndarray.NDArray) + assert np.array_equal(out3d.numpy(), np.array([3, 3, 2])) + + out4d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2])) + assert isinstance(out4d, tvm.runtime.ndarray.NDArray) + assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2])) + + outs = run_cpu(ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2])) + assert isinstance(outs, tvm.runtime.ndarray.NDArray) + assert np.array_equal(outs.numpy(), np.array([3, 2])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py new file mode 100644 index 000000000000..d3cd47b9e699 --- /dev/null +++ b/tests/python/relax/test_relay_translator.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tempfile + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm import relax, relay, tir, topi +from tvm.ir.base import assert_structural_equal +from tvm.relax.testing import relay_translator +from tvm.relay import testing +from tvm.runtime import vm +from tvm.script import tir as T +from tvm.target import Target + + +def get_resnet(batch_size, dtype, layout, image_shape): + relay_mod, params = testing.resnet.get_workload( + num_layers=18, + batch_size=batch_size, + dtype=dtype, + layout=layout, + image_shape=image_shape, + ) + + return relay_mod, params + + +def relay_build_and_run(mod, target, dev, params, data): + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + task_scheduler="round-robin", + work_dir=work_dir, + ) + ex = ms.relay_integration.compile_relay( + db, + mod=mod, + target=target, + params=params, + ) + rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev)) + rt_mod.set_input("data", data) + rt_mod.run() + out = rt_mod.get_output(0).numpy() + return ex, rt_mod, out + + +def relax_build_and_run(mod, target, dev, params, data): + mod = relax.transform.BindParams("main", params)(mod) + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relax_integration.tune_relax( + mod=mod, + target=target, + task_scheduler="round-robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + work_dir=work_dir, + ) + ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=params, + ) + vm = relax.VirtualMachine(ex, dev) + res = vm["main"](data) + out = res.numpy() + return ex, vm, out + + +def verify_e2e_translation(target_str, layout, batch_size, image_shape): + target = Target(target_str) + dev = tvm.device(str(target), dev_id=0) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + input_shape = (1, *image_shape) + data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev) + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + assert relax_mod["main"].attrs["global_symbol"] == "main" + + _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) + _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) + tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5) + + +@pytest.mark.skip(reason="take too much time") +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_cpu(layout, batch_size, image_shape): + verify_e2e_translation("llvm --num-cores=16", layout, batch_size, image_shape) + + 
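+# The GPU variant below mirrors the CPU test with a CUDA target and the same layouts; it is likewise skipped by default because end-to-end tuning takes too long.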
+@pytest.mark.skip(reason="take too much time") +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): + verify_e2e_translation("cuda", layout, batch_size, image_shape) + + +def verify_extracted_tasks(target_str, layout, batch_size, image_shape): + target = Target(target_str) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + relax_mod = relay_translator.from_relay( + relay_mod["main"], + target, + params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relay_tasks = ms.relay_integration.extract_tasks( + relay_mod, + target=target, + params=params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relax_tasks = ms.relax_integration.extract_tasks( + relax_mod, + target=target, + params=params, + ) + # TODO (yongwww, yuchen): tophub guides relay passes, which causes inconsistent tasks + # assert len(relay_tasks) == len(relax_tasks) + # TODO: Can we compare extracted tasks as well? + + +@pytest.mark.parametrize( + "layout, batch_size, image_shape", + [ + ("NCHW", 1, (3, 224, 224)), + ("NHWC", 1, (224, 224, 3)), + ], +) +def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape): + verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, image_shape) + + +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape): + verify_extracted_tasks("cuda", layout, batch_size, image_shape) + + +def translate_and_build_vms(relay_mod, target_str="llvm", translate_op_with_tir=None): + target = tvm.target.Target(target_str) + + # build the relay IRModule and create relay vm + relay_ex = relay.vm.compile(relay_mod, target) + relay_vm = vm.VirtualMachine(relay_ex, tvm.cpu()) + + # build the relax IRModule and create relax vm + relax_mod = relay_translator.from_relay( + relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir + ) + relax_ex = relax.build(relax_mod, target) + relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu()) + + return relay_vm, relax_vm, relax_mod + + +def verify_vm_outputs( + input_shape, + relay_vm, + relax_vm, + extra_args=[], +): + input = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32)) + + # check correctness by comparing relax and relay result + args = [input] + extra_args + relax_output = relax_vm["main"](*args) + relay_output = relay_vm.run(*args) + tvm.testing.assert_allclose(relay_output.numpy(), relax_output.numpy()) + + +def test_single_dynamic_dim(): + wx, wy = 64, 128 + # create relay module: y = data * weights + bias with dynamic batch dimension + data = relay.var("data", shape=(relay.Any(), wx)) + weights = relay.var("weights", shape=(wx, wy)) + bias = relay.var("bias", shape=(wy,)) + y = relay.nn.matmul(data, weights) + relay_mod = tvm.IRModule.from_expr(relay.Function([data, weights, bias], y + bias)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + weights = tvm.nd.array(np.random.rand(wx, wy).astype(np.float32)) + bias = tvm.nd.array(np.random.rand(wy).astype(np.float32)) + # verify for different batch sizes + verify_vm_outputs([10, wx], relay_vm, relax_vm, [weights, bias]) + verify_vm_outputs([32, wx], 
relay_vm, relax_vm, [weights, bias]) + + +def test_multiple_dynamic_dims(): + # create relay module: y = a + a, where a has shape = (?, 5, ?) + shape = (relay.Any(), 5, relay.Any()) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a + a)) + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + # verify for different shapes + verify_vm_outputs([2, 5, 10], relay_vm, relax_vm) + verify_vm_outputs([12, 5, 24], relay_vm, relax_vm) + + +def test_layout_transform(): + shape = (1, 3, 224, 224) + a = relay.var("a", shape=shape) + b = relay.layout_transform(a, "NCHW", "NHWC") + relay_mod = tvm.IRModule.from_expr(relay.Function([a], b)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + verify_vm_outputs([1, 3, 224, 224], relay_vm, relax_vm) + + +def test_translate_op_with_tir(): + @T.prim_func + def tir_matmul( + A: T.Buffer((512, 512), "float32"), + B: T.Buffer((512, 512), "float32"), + C: T.Buffer((512, 512), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + shape = (512, 512) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a)) + _, _, relax_mod = translate_and_build_vms( + relay_mod, translate_op_with_tir={"multiply": tir_matmul} + ) + assert_structural_equal(relax_mod["multiply"], tir_matmul) + + +def test_translate_tuple_arg(): + x = relay.var("x", shape=(10, 16)) + y = relay.var("y", shape=(10, 16)) + relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1))) + relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm") + + # Construct the expected module + bb = relax.BlockBuilder() + x_relax = relax.Var("x", relax.TensorStructInfo([10, 16], "float32")) + y_relax = relax.Var("y", relax.TensorStructInfo([10, 16], "float32")) + with bb.function("main", [x_relax, y_relax]): + with bb.dataflow(): + _ = bb.emit(relax.Tuple((x_relax, y_relax))) + lv1 = bb.emit(x_relax) + lv2 = bb.emit(y_relax) + lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + assert_structural_equal(relax_mod, bb.get()) + + +def test_append_op_attrs(): + x = relay.var("x", shape=(10, 16)) + y = relay.var("y", shape=(10, 16)) + relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1))) + relax_mod_wo_attrs = relay_translator.from_relay(relay_mod["main"], target="llvm") + relax_mod_with_attrs = relay_translator.from_relay( + relay_mod["main"], target="llvm", append_op_attrs=True + ) + assert "op_attrs" in relax_mod_with_attrs["concatenate"].attrs + assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py new file mode 100644 index 000000000000..b4ba54b45554 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import pytest +import numpy as np + +from tvm.ir import assert_structural_equal +from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode + + +def test_make_shape(): + MK = MakeShapeCode + make_shape = tvm.get_global_func("vm.builtin.make_shape") + heap = tvm.nd.array(np.arange(10).astype("int64")) + s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2) + + assert s == tvm.runtime.container.ShapeTuple([10, 0, 2]) + + +def test_match_shape(): + MS = MatchShapeCode + match_shape = tvm.get_global_func("vm.builtin.match_shape") + heap = tvm.nd.array(np.zeros(10).astype("int64")) + + assert heap.numpy()[2] == 0 + + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + x = tvm.nd.array(np.zeros([1, 2, 3])) + + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + assert heap.numpy()[2] == 2 + + match_shape( + x, + heap, + 3, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_LOAD, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 3, + "", + ) + + with pytest.raises(RuntimeError): + match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "") + + with pytest.raises(RuntimeError): + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + +def test_check_shape_info(): + check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info") + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + + check_shape_info(s, 3, "") + check_shape_info(s, -1, "") + + # wrong ndim + with pytest.raises(ValueError): + check_shape_info(s, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_shape_info([], 2, "") + + +def test_check_tensor_info(): + check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_tensor_info(x, 2, "int32", "") + check_tensor_info(x, -1, "int32", "") + check_tensor_info(x, 2, "", "") + check_tensor_info(x, -1, "", "") + + # allow not passing in dtype + check_tensor_info(x, 2, "") + check_tensor_info(x, -1, "") + + # ndim mismatch + with pytest.raises(ValueError, match=r".* ndim .*"): + check_tensor_info(x, 3, "int32", "") + + # dtype mismatch + with pytest.raises(ValueError, match=r"myerror.* dtype .*"): + check_tensor_info(x, 2, "float32", "myerror") + + # error with context + with pytest.raises(ValueError, match=r".* myerror .*"): + check_tensor_info(x, 3, "myerror") + + # wrong type + with pytest.raises(TypeError): + check_tensor_info([], 2, "", "") + + +def test_check_tuple_info(): + check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, x, x]) + + check_tuple_info(t, 3, "") + + # size + with pytest.raises(ValueError, match=r".*elements.*"): + check_tuple_info(t, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_tuple_info(x, 2, "") + + +def test_check_func_info(): + check_func_info = 
tvm.get_global_func("vm.builtin.check_func_info") + f = tvm.runtime.convert(lambda x: x) + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_func_info(f, "") + + # wrong type + with pytest.raises(TypeError, match=".*myerror.*"): + check_func_info(x, "myerror") + + +def test_tuple_getitem(): + tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + y = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, y]) + + assert tuple_getitem(t, 0) == x + assert tuple_getitem(t, 1) == y + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py new file mode 100644 index 000000000000..80ebc3cb182a --- /dev/null +++ b/tests/python/relax/test_struct_info.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +import pytest + +from tvm import relax as rx, TVMError, tir + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_object_struct_info(): + s0 = rx.ObjectStructInfo() + s1 = rx.ObjectStructInfo() + + # can turn into str + str(s0) + _check_equal(s0, s1) + + assert isinstance(s0, rx.ObjectStructInfo) + _check_json_roundtrip(s0) + + +def test_shape_type(): + t0 = rx.ShapeType() + t1 = rx.ShapeType() + assert t0 == t1 + + +def test_dyn_tensor_type(): + t0 = rx.DynTensorType() + assert t0.ndim == -1 + t1 = rx.DynTensorType(3, "int32") + assert t1.ndim == 3 + assert t1.dtype == "int32" + + +def test_prim_struct_info(): + s0 = rx.PrimStructInfo("float32") + s1 = rx.PrimStructInfo("float32") + s2 = rx.PrimStructInfo("int32") + + _check_equal(s0, s1) + + # can turn into str + str(s0) + + assert s0 == s1 + assert s0 != s2 + + assert isinstance(s0, rx.PrimStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + assert s1.dtype == "float32" + assert s2.dtype == "int32" + + # wrong API constructors + with pytest.raises(TVMError): + rx.PrimStructInfo(1) + + +def test_shape_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.ShapeStructInfo([1, n + 1, m]) + s1 = rx.ShapeStructInfo([1, n + 1, m]) + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert s0.values[2] == m + + assert isinstance(s0, rx.ShapeStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = 
rx.ShapeStructInfo(ndim=2) + + assert s2.ndim == 2 + assert s2.values is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # can turn into str + str(s0) + + # wrong argument type + with pytest.raises(TVMError): + rx.ShapeStructInfo(1) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=2) + + +def test_tensor_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, n + 1, m], "float32") + s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert isinstance(s0, rx.TensorStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.TensorStructInfo(ndim=2, dtype="int32") + + assert s2.ndim == 2 + assert s2.dtype == "int32" + assert s2.shape is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # take in opaque var + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + + s3 = rx.TensorStructInfo(rshape, dtype="int32") + assert s3.dtype == "int32" + assert s3.shape == rshape + assert s3.ndim == 2 + _check_json_roundtrip(s3) + + # can turn into str + str(s0) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=2) + + +def test_tuple_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, 2, m + n], "float32") + s1 = rx.ObjectStructInfo() + + t0 = rx.TupleStructInfo([s0, s1]) + t1 = rx.TupleStructInfo([s0, rx.ObjectStructInfo()]) + t2 = rx.TupleStructInfo([s0, s0]) + + _check_equal(t0, t1) + + assert t0 == t1 + + assert isinstance(t0, rx.TupleStructInfo) + t0 = _check_json_roundtrip(t0) + t1 = _check_json_roundtrip(t1) + t2 = _check_json_roundtrip(t2) + + # can turn into str + str(t0) + + # wrong argument type + with pytest.raises(TVMError): + rx.TupleStructInfo(1) + + +def test_func_struct_info(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n, m], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + f1 = fn_info(1) + f2 = fn_info(2) + f3 = rx.FuncStructInfo.opaque_func() + + _check_equal(f0, f1) + + assert f0 == f1 + assert f0 != f2 + + assert len(f0.params) == 2 + assert isinstance(f0.ret, rx.TensorStructInfo) + assert f2.derive_func is None + assert f3.params is None + assert f3.derive_func is None + _check_equal(f3.ret, rx.ObjectStructInfo()) + + assert isinstance(f0, rx.FuncStructInfo) + f0 = _check_json_roundtrip(f0) + f1 = _check_json_roundtrip(f1) + f2 = _check_json_roundtrip(f2) + f3 = _check_json_roundtrip(f3) + + # can turn into str + str(f3) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py new file mode 100644 index 000000000000..2dc06b4a9d51 --- /dev/null +++ b/tests/python/relax/test_transform.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm import relax +from tvm.ir import structural_equal +from tvm.ir.base import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_to_non_dataflow(): + @tvm.script.ir_module + class TestToNonDataflow: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + with R.dataflow(): + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + ) + R.output(gv0) + return gv0 + + mod = TestToNonDataflow + + old_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal old_vars + old_vars.append(e) + + relax.analysis.post_order_visit(mod["foo"], fvisit) + x, lv0, gv0 = old_vars + + new_mod = relax.transform.ToNonDataflow()(mod) + + new_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal new_vars + new_vars.append(e) + + relax.analysis.post_order_visit(new_mod["foo"], fvisit) + + assert x == new_vars[0] + assert lv0 != new_vars[1] + assert isinstance(lv0, relax.DataflowVar) + assert not isinstance(new_vars[1], relax.DataflowVar) + + assert isinstance(gv0, relax.Var) + assert isinstance(new_vars[2], relax.Var) + assert gv0 == new_vars[2] + + +def test_call_tir_rewrite(): + @tvm.script.ir_module + class TestCallTIRRewrite: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + gv0 = R.call_tir(TestCallTIRRewrite.exp, (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIRRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_tir" + + # after rewrite + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + tvm.ir.expr.GlobalVar + assert s2.op.name_hint == "exp" + + +def test_call_dps_packed_rewrite(): + @tvm.script.ir_module + class TestCallDPSPackedRewrite: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallDPSPackedRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = 
mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_dps_packed" + + # CallTIRRewrite also works for call_dps_packed + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + + +def test_vm_builtin_lower(): + @tvm.script.ir_module + class TestVMBuiltinLower: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: + m, n = T.int64(), T.int64() + alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") + _ = R.call_packed( + "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) + gv0 = alloc + return gv0 + + mod = TestVMBuiltinLower + + # after vm builtin lowering + new_mod = relax.transform.VMBuiltinLower()(mod) + func = new_mod["foo"] + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.vm.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, relax.Call) + s3 = block.bindings[2].value + assert isinstance(s3, relax.Call) + assert isinstance(s3.op, relax.ExternFunc) + assert s3.op.global_symbol == "test.op.identity" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py new file mode 100644 index 000000000000..77e2d4e35986 --- /dev/null +++ b/tests/python/relax/test_transform_alter_op_impl.py @@ -0,0 +1,342 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +import tvm.testing + +from tvm import relax +from tvm.script import tir as T, ir as I, relax as R + +kOperatorName = "operator_name" + + +def _check(before, expected, operator_name, replacement_primfunc, layout_changes): + after = relax.transform.AlterOpImpl( + {operator_name: replacement_primfunc}, {operator_name: layout_changes} + )(before) + after = relax.transform.DeadCodeElimination()(after) + tvm.ir.assert_structural_equal(after, expected) + + +def test_single_output(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.add"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0], arg1[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = arg0[v_ax0] + arg1[v_ax0] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32")) + gv: R.Tensor((16,), dtype="float32") = lv + R.output(gv) + return gv + @I.ir_module + class Expected: + @T.prim_func + def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): + T.func_attr({"operator_name": "relax.add"}) + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) + gv: R.Tensor((16,), dtype="float32") = lv_1 + R.output(gv) + return gv + + @T.prim_func + def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + # fmt: on + index_map = lambda i: (i // 4, i % 4) + _check( + Before, + Expected, + operator_name="relax.add", + replacement_primfunc=add_2d, + layout_changes=[index_map, index_map, index_map], + ) + + +def test_empty_layout_changes(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def mul_by_2(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.mul_by_2"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = arg0[v_ax0] * T.float32(2) + + @R.function + def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv = 
R.call_tir(Before.mul_by_2, (x,), out_sinfo=R.Tensor((16,), dtype="float32")) + gv: R.Tensor((16,), dtype="float32") = lv + R.output(gv) + return gv + @I.ir_module + class Expected: + @T.prim_func + def relax_mul_by_2_replacement(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.mul_by_2"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = arg0[v_ax0] + arg0[v_ax0] + + @R.function + def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv = R.call_tir(Expected.relax_mul_by_2_replacement, (x,), out_sinfo=R.Tensor((16,), dtype="float32")) + gv: R.Tensor((16,), dtype="float32") = lv + R.output(gv) + return gv + + @T.prim_func + def add_x_x(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.mul_by_2"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = arg0[v_ax0] + arg0[v_ax0] + # fmt: on + _check( + Before, + Expected, + operator_name="relax.mul_by_2", + replacement_primfunc=add_x_x, + layout_changes=[], + ) + + +def test_multiple_outputs(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.some_op"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0], arg1[v_ax0]) + T.writes(output0[v_ax0], output1[v_ax0]) + output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0] + output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")): + with R.dataflow(): + gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): + T.func_attr({"operator_name": "relax.some_op"}) + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) + output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) + lv3: R.Tensor((4, 4), dtype="float32") = lv2[0] + lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 
* 4 + axis1,), pad_value=None) + lv5: R.Tensor((4, 4), dtype="float32") = lv2[1] + lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) + gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = (lv4, lv6) + R.output(gv) + return gv + + @T.prim_func + def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) + output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1] + # fmt: on + + index_map = lambda i: (i // 4, i % 4) + _check( + Before, + Expected, + operator_name="relax.some_op", + replacement_primfunc=some_op_2d, + layout_changes=[index_map, index_map, index_map, index_map], + ) + + +def test_unsupported_implicit_padding(): + @I.ir_module + class InputModule: + @R.function + def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"): + with R.dataflow(): + lv = R.call_tir(InputModule.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32")) + gv: R.Tensor((14,), dtype="float32") = lv + R.output(gv) + return gv + + @T.prim_func + def relu(arg0: T.Buffer((14,), "float32"), output: T.Buffer((14,), "float32")): + T.func_attr({"operator_name": "relax.relu"}) + for ax0 in T.grid(14): + with T.block("T_add"): + v_ax0 = T.axis.remap("S", [ax0]) + T.reads(arg0[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = T.max(arg0[v_ax0], T.float32(0)) + + before = InputModule + + @T.prim_func + def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + for ax0 in T.grid(16): + with T.block("T_add"): + v_ax0 = T.axis.remap("S", [ax0]) + T.reads(arg0[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = T.max(arg0[v_ax0], T.float32(0)) + + # introduces implicit padding for shape (14,) + index_map = lambda i: (i % 16) + operator_name = "relax.relu" + with pytest.raises( + tvm.TVMError, match="Non bijective transforms on input and output buffers are not supported" + ): + _ = relax.transform.AlterOpImpl( + {operator_name: relu_pad}, {operator_name: [index_map, index_map]} + )(before) + + +def test_multiple_call_sites(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): + T.func_attr({"operator_name": "relax.add"}) + for ax0 in range(16): + with T.block("T_add"): + v_ax0 = T.axis.spatial(16, ax0) + T.reads(arg0[v_ax0], arg1[v_ax0]) + T.writes(output[v_ax0]) + output[v_ax0] = arg0[v_ax0] + arg1[v_ax0] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv0 = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32")) + lv1 = R.nn.relu(lv0) + lv2 = R.call_tir(Before.add, (lv0, lv1), out_sinfo=R.Tensor((16,), dtype="float32")) + gv: R.Tensor((16,), dtype="float32") = lv2 + R.output(gv) + return gv + @I.ir_module + class Expected: + @T.prim_func + def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): + T.func_attr({"operator_name": "relax.add"}) + # with 
T.block("root"): + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + + @R.function + def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) + lv1_1: R.Tensor((16,), dtype="float32") = R.nn.relu(lv0) + lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv1_1, index_map=lambda i: (i // 4, i % 4), pad_value=None) + lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) + gv: R.Tensor((16,), dtype="float32") = lv2_1 + R.output(gv) + return gv + @T.prim_func + def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): + for ax0, ax1 in T.grid(4, 4): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] + # fmt: on + index_map = lambda i: (i // 4, i % 4) + _check( + Before, + Expected, + operator_name="relax.add", + replacement_primfunc=add_2d, + layout_changes=[index_map, index_map, index_map], + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py new file mode 100644 index 000000000000..c2f0e7af5fba --- /dev/null +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -0,0 +1,406 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import enum + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import tir as T + + +class OpPatternKind(enum.IntEnum): + kElemWise = 0 + kBroadcast = 1 + kInjective = 2 + kCommReduce = 3 + kOutEWiseFusable = 4 + kTuple = 7 + kOpaque = 8 + + +def test_annotate_opkind_outewisefusable(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.int32() + n = T.int32() + k = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_outewisefusable_int_var_signature(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64): + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_reduce(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def sum(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16,)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SR", [i, j]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + +def test_annotate_opkind_ewise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def elemwise(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_broadcast(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def broadcast(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16, 16, 16)) + + for i0, j0, i1, j1 in T.grid(16, 16, 16, 16): + with T.block("matmul"): + vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1]) + B[vi0, vj0, vi1, vj1] = A[vj0, vj1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast + + +def test_annotate_opkind_injective(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def 
injective(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (4, 4, 4, 4)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective + + +def test_annotate_opkind_bias_add(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_bias_add( + A: T.Buffer((1, 1000), "float32"), + B: T.Buffer((1000,), "float32"), + C: T.Buffer((1, 1000), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1 in T.grid(1, 1000): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_broadcast_with_unit_shape(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_with_unit_dim_len_broadcast( + A: T.Buffer((1, 64, 112, 112), "float32"), + B: T.Buffer((64, 1, 1), "float32"), + C: T.Buffer((1, 64, 112, 112), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add5", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0]) + T.writes(C[ax0, ax1, ax2, ax3]) + C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_zero_dim_element_wise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_zero_dim( + A: T.Buffer((128,), "float32"), + B: T.Buffer((), "float32"), + C: T.Buffer((128,), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add8", "tir.noalias": True}) + for i0 in T.serial(128): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0) + T.reads(A[ax0], B[()]) + T.writes(C[ax0]) + C[ax0] = A[ax0] + B[()] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_pooling(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def max_pool2d( + rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"), + tensor_1: T.Buffer((1, 64, 56, 56), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True}) + # body + # with T.block("root") + pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 64, 114, 114): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) + T.writes(pad_temp_1[ax0, ax1, ax2, ax3]) + pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else( + 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, + rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], + T.float32(-3.4028234663852886e38), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 
56, 56, 3, 3): + with T.block("tensor"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + T.writes(tensor_1[ax0, ax1, ax2, ax3]) + with T.init(): + tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) + tensor_1[ax0, ax1, ax2, ax3] = T.max( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_softmax(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def softmax( + rxplaceholder_1: T.Buffer((16, 16), "float32"), + T_softmax_norm_1: T.Buffer((16, 16), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "softmax", "T.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32") + T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32") + T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32") + for i0_7, i1_3 in T.grid(16, 16): + with T.block("T_softmax_maxelem"): + i0_8, k = T.axis.remap("SR", [i0_7, i1_3]) + T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) + T.writes(T_softmax_maxelem_1[i0_8]) + with T.init(): + T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_1[i0_8] = T.max( + T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k] + ) + for i0_9, i1_4 in T.grid(16, 16): + with T.block("T_softmax_exp"): + i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4]) + T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10]) + T.writes(T_softmax_exp_1[i0_10, i1_5]) + T_softmax_exp_1[i0_10, i1_5] = T.exp( + rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32" + ) + for i0_11, i1_6 in T.grid(16, 16): + with T.block("T_softmax_expsum"): + i0_12, k = T.axis.remap("SR", [i0_11, i1_6]) + T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k]) + T.writes(T_softmax_expsum_1[i0_12]) + with T.init(): + T_softmax_expsum_1[i0_12] = T.float32(0) + T_softmax_expsum_1[i0_12] = ( + T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k] + ) + for i0_13, i1_7 in T.grid(16, 16): + with T.block("T_softmax_norm"): + i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7]) + T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14]) + T.writes(T_softmax_norm_1[i0_14, i1_8]) + T.block_attr({"axis": 1}) + T_softmax_norm_1[i0_14, i1_8] = ( + T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14] + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_multiple_bufer_stores_fallback(): + @tvm.script.ir_module + class CumsumModule: + @T.prim_func + def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")): + rxplaceholder = T.match_buffer( + var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1 + ) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[0:10, 0:16]) + T.writes(out_buf[0:160]) + for fused in T.parallel(1): + out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16] + for v_k in T.serial(159): + out_buf[fused * 160 + (v_k + 1)] = ( + out_buf[fused * 160 + (v_k + 1 - 1)] + + rxplaceholder[ + (fused * 160 + (v_k + 1)) // 16, + (fused * 160 + (v_k + 1)) % 16, + ] + ) + + mod = CumsumModule + new_mod = 
relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque + + +def test_sum_sqsum(): + @tvm.script.ir_module + class Module: + @T.prim_func + def sum_sqsum( + A: T.Buffer((32, 64), "float32"), + vsum: T.Buffer((32,), "float32"), + sqsum: T.Buffer((32,), "float32"), + ): + for ax0, k0 in T.grid(32, 64): + with T.block("block"): + v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0]) + T.reads(A[v_ax0, v_k0]) + T.writes(vsum[v_ax0], sqsum[v_ax0]) + with T.init(): + vsum[v_ax0] = T.float32(0) + sqsum[v_ax0] = T.float32(0) + v_vsum: T.float32 = vsum[v_ax0] + A[v_ax0, v_k0] + v_sqsum: T.float32 = sqsum[v_ax0] + A[v_ax0, v_k0] * A[v_ax0, v_k0] + vsum[v_ax0] = v_vsum + sqsum[v_ax0] = v_sqsum + + mod = Module + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum_sqsum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + +def test_no_buffer_stores(): + @tvm.script.ir_module + class Module: + @T.prim_func + def no_buffer_stores(A: T.Buffer((32, 64), "float32"), vsum: T.Buffer((32,), "float32")): + for ax0, k0 in T.grid(32, 64): + with T.block("block"): + v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0]) + T.reads(A[v_ax0, v_k0]) + T.writes(vsum[v_ax0]) + # absence of buffer stores usually happens when there is an external call for + # computation. We assume opaque in all such cases. + T.call_packed("some_func") + + mod = Module + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["no_buffer_stores"].attrs["op_pattern"] == OpPatternKind.kOpaque + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py new file mode 100644 index 000000000000..0937eed25a9a --- /dev/null +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
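Note: the AnnotateTIROpPattern tests above check which op_pattern kind each PrimFunc is assigned; downstream fusion consumes that attribute. The sketch below is a deliberately simplified, self-contained illustration of one consumer-side check; only the enum values come from the tests above, while the fusion rule shown here is an assumption and not the actual FuseOps algorithm:

import enum

class OpPatternKind(enum.IntEnum):
    kElemWise = 0
    kBroadcast = 1
    kInjective = 2
    kCommReduce = 3
    kOutEWiseFusable = 4
    kTuple = 7
    kOpaque = 8

def fusable_as_consumer(kind: OpPatternKind) -> bool:
    # Simplified rule: injective-or-simpler consumers can be fused into a
    # producer; reductions, tuples and opaque calls cannot.
    return kind <= OpPatternKind.kInjective

assert fusable_as_consumer(OpPatternKind.kElemWise)
assert fusable_as_consumer(OpPatternKind.kBroadcast)
assert not fusable_as_consumer(OpPatternKind.kOpaque)
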
+ +import pytest +import tvm +from tvm import tir, relax +from tvm.ir import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class Before: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + m = T.int64() + n = T.int64() + k = T.int64() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: + m, n, k = T.int64(), T.int64(), T.int64() + gv0 = R.call_tir(Before.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + +def test_basic(): + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.int64() + n = T.int64() + k = T.int64() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) + m, n, k = T.int64(), T.int64(), T.int64() + gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + before = Before + expected = Expected + after = relax.transform.AttachGlobalSymbol()(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py new file mode 100644 index 000000000000..8e760b6fd70f --- /dev/null +++ b/tests/python/relax/test_transform_bind_params.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
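Note: the AttachGlobalSymbol test above verifies that every function in the module receives a "global_symbol" attribute equal to its own name, so the built artifact can export it by symbol. A minimal sketch of that idea using plain dicts as a stand-in for an IRModule; all names here are illustrative only:

def attach_global_symbol(module: dict) -> dict:
    # Give every function a "global_symbol" attribute matching its name.
    return {
        name: dict(func, attrs=dict(func.get("attrs", {}), global_symbol=name))
        for name, func in module.items()
    }

mod = {"tir_matmul": {"attrs": {}}, "main": {}}
out = attach_global_symbol(mod)
assert out["tir_matmul"]["attrs"]["global_symbol"] == "tir_matmul"
assert out["main"]["attrs"]["global_symbol"] == "main"
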
+ +import numpy as np +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + +use_np_array = tvm.testing.parameter(False, True) + + +def test_bind_params(use_np_array): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + C = T.match_buffer(z, (16, 16)) + for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): + with T.block("matmul"): + vi = T.axis.S(16, i0 * 4 + i1) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k0 * 4 + k1) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + x_np = np.random.rand(16, 16).astype(np.float32) + w_np = np.random.rand(16, 16).astype(np.float32) + x_tvm = tvm.nd.array(x_np) + w_tvm = tvm.nd.array(w_np) + params_dict = {"w": w_np if use_np_array else w_tvm} + mod = relax.transform.BindParams("main", params_dict)(InputModule) + assert len(mod["main"].params) == 1 + + target = tvm.target.Target("llvm") + ex_after = relax.build(mod, target) + vm_after = relax.VirtualMachine(ex_after, tvm.cpu()) + res_after = vm_after["main"](x_tvm) + + ex_before = relax.build(InputModule, target) + vm_before = relax.VirtualMachine(ex_before, tvm.cpu()) + res_before = vm_before["main"](x_tvm, w_tvm) + + tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy()) + + +def test_bind_params_symbolic_vars(): + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(("batch", "m"), dtype="float32"), + w0: R.Tensor(("n", "m"), dtype="float32"), + b0: R.Tensor(("n",), dtype="float32"), + w1: R.Tensor(("k", "n"), dtype="float32"), + b1: R.Tensor(("k",), dtype="float32"), + ) -> R.Tensor(("batch", "k"), dtype="float32"): + batch = T.Var("batch", "int64") + k = T.Var("k", "int64") + m = T.Var("m", "int64") + n = T.Var("n", "int64") + with R.dataflow(): + lv0 = R.call_dps_packed( + "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32") + ) + out = R.call_dps_packed( + "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32") + ) + R.output(out) + return out + + m, n, k = 4, 6, 8 + w0_tvm = tvm.nd.array(np.random.rand(n, m).astype(np.float32)) + b0_tvm = tvm.nd.array(np.random.rand(n).astype(np.float32)) + w1_tvm = tvm.nd.array(np.random.rand(k, n).astype(np.float32)) + b1_tvm = tvm.nd.array(np.random.rand(k).astype(np.float32)) + params_dict = {"w0": w0_tvm, "b0": b0_tvm, "w1": w1_tvm, "b1": b1_tvm} + mod = relax.transform.BindParams("main", params_dict)(Before) + + # Since it contains ConstantNode, it's hard to check with structural equality. 
+ func = mod["main"] + assert len(func.params) == 1 + batch = func.params[0].struct_info.shape[0] + tvm.ir.assert_structural_equal( + func.params[0].struct_info, relax.TensorStructInfo((batch, 4), "float32") + ) + tvm.ir.assert_structural_equal( + func.ret_struct_info, relax.TensorStructInfo((batch, 8), "float32") + ) + bindings = func.body.blocks[0].bindings + tvm.ir.assert_structural_equal( + bindings[0].var.struct_info, relax.TensorStructInfo((batch, 6), "float32") + ) + tvm.ir.assert_structural_equal( + bindings[1].var.struct_info, relax.TensorStructInfo((batch, 8), "float32") + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py new file mode 100644 index 000000000000..086c316ae817 --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.script +import tvm.testing +import pytest +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script import relax as R, tir as T + + +def test_simple_assignments(): + @tvm.script.ir_module + class TestChainAssignments: + @R.function + def main(x: R.Tensor): + y = x + z = y + q = z + p = q + o = p + return o + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + y = x + z = x + q = x + p = x + o = x + return x + + new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_dataflow_block(): + @tvm.script.ir_module + class TestDataflowAssignments: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.const(1) + z = y + o = z + p = o + m = p + n = m + R.output(n) + return n + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.const(1) + z = y + o = y + p = y + m = y + # we can't get rid of n because it leaves the block + n = y + R.output(n) + return n + + new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_ops(): + @tvm.script.ir_module + class TestOps: + @R.function + def main(x: R.Tensor, y: R.Tensor): + w = y + q = x + z = R.add(w, q) + return R.add(q, z) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor, y: R.Tensor): + w = y + q = x + z = R.add(y, x) + return R.add(x, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestOps) + 
assert_structural_equal(new_mod, Expected) + + +@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.") +def test_casting(): + @tvm.script.ir_module + class TestCasting: + @R.function + def main(x: R.Tensor) -> R.Object: + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = y + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor) -> R.Object: + y = x + # Cannot unify because the cast indicates user intent + z: R.Object = x + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestCasting) + assert_structural_equal(new_mod, Expected) + + +def test_match_cast(): + @tvm.script.ir_module + class TestMatchCast: + @R.function + def main(x: R.Tensor): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + q = x + # can't get rid of z because its shape_ is different from x's + m, n = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((m, n))) + w = z + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast) + assert_structural_equal(new_mod, Expected) + + +def test_same_shape(): + @tvm.script.ir_module + class TestSameShape: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + y = x + # trivial check + z = R.match_cast(x, R.Tensor((m, n), "float32")) + w = z + q = R.add(w, y) + return R.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + y = x + # canonicalized into a var binding + z = x + w = x + q = R.add(x, x) + return R.add(q, x) + + new_mod = relax.transform.CanonicalizeBindings()(TestSameShape) + assert_structural_equal(new_mod, Expected) + + +def test_change_shape(): + @tvm.script.ir_module + class TestChangeShape: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + # not trivial: introduces new shape vars + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((o, p))) + w = z + q = R.add(w, y) + return R.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((o, p))) + w = z + # the shape_ field on q will need to be updated + q = R.add(z, x) + return R.add(q, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestChangeShape) + assert_structural_equal(new_mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py new file mode 100644 index 000000000000..77756dc66474 --- /dev/null +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import os +import tvm +import tvm.testing +from tvm import relax, tir +import numpy as np +from tvm.script import relax as R +from tvm.relax.testing import transform +import tempfile +from tvm.relax.transform.tuning_api import Trace +from tvm.relax.dpl import is_op, wildcard + +env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) +env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) + +has_tensorrt_codegen = pytest.mark.skipif( + not env_checker_codegen, + reason="TensorRT codegen not available", +) +has_tensorrt_runtime = pytest.mark.skipif( + not env_checker_runtime or not env_checker_runtime(), + reason="TensorRT runtime not available", +) + +# Global variable in pytest that applies markers to all tests. +pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] + +# Target gpu +target_str = "nvidia/nvidia-t4" +target = tvm.target.Target(target_str) +dev = tvm.cuda() + + +def check_executable(exec, dev, inputs, expected): + vm = relax.VirtualMachine(exec, dev) + out = vm["main"](*inputs) + tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5) + + +def check_roundtrip(exec0, dev, inputs, expected): + exec0.mod.export_library("exec.so") + exec1 = tvm.runtime.load_module("exec.so") + os.remove("exec.so") + assert exec0.stats() == exec1["stats"]() + assert exec0.as_text() == exec1["as_text"]() + + check_executable(exec0, dev, inputs, expected) + check_executable(exec1, dev, inputs, expected) + + +def gen_ground_truth(mod, target, dev, inputs): + # Lower and run tuning + # Since there is no default schedule for GPU in MS yet, this is necessary + with target: + seq = tvm.transform.Sequential( + [relax.transform.LegalizeOps(), tir.transform.DefaultGPUSchedule()] + ) + new_mod = seq(mod) + assert relax.analysis.well_formed(new_mod) + exec = relax.build(new_mod, target, params={}) + vm = relax.VirtualMachine(exec, dev) + return vm["main"](*inputs) + + +@tvm.script.ir_module +class InputModule: + @R.function + def main( + x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + with R.dataflow(): + z1 = R.multiply(x, y) + z2 = R.add(z1, x) + z3 = R.add(z1, z2) + z4 = R.multiply(z3, z2) + z5 = R.add(z4, z1) + R.output(z5) + return z5 + + +def setup_test(): + # Prepare IRModule and its input + mod = InputModule + assert isinstance(mod, tvm.IRModule) + + np0 = np.random.rand(16, 16).astype(np.float32) + np1 = np.random.rand(16, 16).astype(np.float32) + data0 = tvm.nd.array(np0, dev) + data1 = tvm.nd.array(np1, dev) + inputs = [data0, data1] + + # Ground truth should be generated before annotation + # due to the conflict with MS task extraction + # TODO(@sunggg): Sort this out + expected = gen_ground_truth(mod, target, dev, inputs) + return mod, inputs, expected + + +@tvm.testing.requires_gpu +def test_tensorrt_only(): + mod, inputs, expected = setup_test() + + # Define patterns that we want to offload to byoc + # This test will offload entire model + # Thus, define patterns for both `multiply` and `add` ops + patterns = [ + ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())), + ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())), + ] + + new_mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(mod) + + ex0 = 
relax.build(new_mod, target, params={}) + # Sanity check for the correctness and roundtrip + check_roundtrip(ex0, dev, inputs, expected) + + +@tvm.testing.requires_gpu +def test_mix_use_tensorrt_and_tvm(): + mod, inputs, expected = setup_test() + + # Define patterns that we want to offload to byoc + # This test will only offload `add` op to tensorrt + # and tune `multiply` op with MetaSchedule + patterns = [ + ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())), + ] + + # Run Codegen pass + with tempfile.TemporaryDirectory() as work_dir: + with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0): + new_mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + relax.transform.LegalizeOps(), + relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=8 + ), + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + )(mod) + assert relax.analysis.well_formed(new_mod) + with transform.PassContext(opt_level=0): + ex0 = relax.build(new_mod, target, params={}) + + # Sanity check for the correctness and roundtrip + check_roundtrip(ex0, dev, inputs, expected) + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + cls = Conv2dx2 + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt( + data, weight1 + ) + gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_tensorrt( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"}) + + @R.function + def gv( + data_1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d( + data_1, + weight1_1, + padding=[1, 1, 1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(gv_1) + return gv_1 + + gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1) + return gv1 + + +@tvm.script.ir_module +class Conv2dx2_after: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + with R.dataflow(): + lv = R.call_dps_packed( + "fused_relax_nn_conv2d_tensorrt", + (data, weight1), + out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + ) + gv = R.call_dps_packed( + "fused_relax_nn_conv2d_tensorrt", + (lv, weight2), + out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + ) + R.output(gv) + return gv + + +def test_multiple_calls_same_extern(): + mod = relax.transform.RunCodegen()(Conv2dx2) + tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"]) + + +# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, 
const binding) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py new file mode 100644 index 000000000000..5187ab30b762 --- /dev/null +++ b/tests/python/relax/test_transform_convert_layout.py @@ -0,0 +1,1406 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_onlydim(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", ndim=4) + ) -> R.Tensor(dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor(dtype="float32", ndim=4) 
= R.permute_dims(lv2, axes=[0, 3, 1, 2]) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_symbolic(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", ndim=4) + ) -> R.Tensor(dtype="float32", ndim=4): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + with R.dataflow(): + lv0: R.Tensor((N, C, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, C, H, W), dtype="float32") + ) + lv: R.Tensor((N, H, W, C), dtype="float32") = R.permute_dims(lv0, axes=[0, 2, 3, 1]) + lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv2, axes=[0, 3, 1, 2]) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_matchcast_bias(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + lv0: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w, out_dtype="float32") + N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64() + lv1 = R.match_cast(lv0, R.Tensor((N, C, H, W), "float32")) + gv = R.add(lv1, w) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", ndim=4) + ) -> R.Tensor(dtype="float32", ndim=4): + N = T.int64() + H = T.int64() + W = T.int64() + C = T.int64() + with R.dataflow(): + lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv0: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( + lv0, R.Tensor((N, H, W, C), dtype="float32") + ) + lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", 
ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_relu_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x0, axes=[0, 2, 3, 1] + ) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_relu_tanh(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_add(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), 
"float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv3, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_add_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims( + lv4, axes=[0, 3, 1, 2] + ) + R.output(gv4) + return gv4 + + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 
26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv, axes=[0, 3, 1, 2] + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + gv3, axes=[0, 2, 3, 1] + ) + lv4: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv5: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d( + lv3, + lv4, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims( + lv5, axes=[0, 3, 1, 2] + ) + R.output(gv4) + return gv4 + + verify(Input, Expected) + + +def test_conv2d_sum(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=2): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_sum_keepdim(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), 
"float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=True) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_sum_negative_dims(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[-2, -1]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4), dtype="float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_transpose(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + gv, axes=[2, 1, 3, 0] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_expand_dims_scalar(): + @I.ir_module + class Input: + @R.function + def main() -> R.Tensor((1,), dtype="int64"): + with R.dataflow(): + gv: R.Tensor((1,), dtype="int64") = 
R.expand_dims(R.const(0, "int64"), axis=[0]) + R.output(gv) + return gv + + verify(Input, Input) + + +def test_conv2d_expand_dims(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=6): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims( + gv, axis=[-3, 1] + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 1, 5, 3, 2, 4] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_expand_dims_squeeze(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.squeeze(gv2, axis=[1, 3]) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims( + gv, axis=[-3, 1] + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.squeeze(gv2, axis=[1, 3]) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_strided_slice(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> 
R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 9, 7, 2), dtype="float32") = R.strided_slice( + gv, axes=[3, 1, 2], begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4] + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_relu_concat(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_relu_concat_split(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3) + gv4: R.Tuple( + R.Tensor((2, 26, 26, 4), dtype="float32"), + R.Tensor((2, 26, 26, 4), dtype="float32"), + ) = R.split(gv3, 
indices_or_sections=2, axis=3) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[0] + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + lv4: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[1] + lv5: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv4, axes=[0, 3, 1, 2] + ) + gv5 = (lv3, lv5) + R.output(gv5) + return gv5 + + verify(Input, Expected) + + +def test_conv2d_maxpool2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_avgpool2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NHWC", out_layout="NHWC" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_softmax(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.softmax(gv, axis=3) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_batchnorm(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tuple( + R.Tensor((2, 26, 26, 4), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm( + gv, + gamma, + beta, + moving_mean, + moving_var, + axis=3, + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv2[0] + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + lv4: R.Tensor((4,), dtype="float32") = gv2[1] + lv5: R.Tensor((4,), dtype="float32") = gv2[2] + gv3 = (lv3, lv4, lv5) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_layernorm(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + 
R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.layer_norm( + gv, + gamma, + beta, + axes=[1, 2], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_resize2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 52, 52, 4), dtype="float32") = R.image.resize2d( + gv, + (52, 52), + roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)], + layout="NHWC", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="void", + ) + gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_unknown_bias_dim(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(None, "float32"): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = w2 + gv + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 
0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv, axes=[0, 3, 1, 2] + ) + gv2: R.Tensor(dtype="float32") = R.add(w2, lv2) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_binary_broadcast(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv, axes=[0, 3, 1, 2] + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(lv2, bias) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py new file mode 100644 index 000000000000..4ee9653ead39 --- /dev/null +++ b/tests/python/relax/test_transform_cse.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
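+# Note: EliminateCommonSubexpr rebinds repeated pure expressions within a dataflow
+# block to their first occurrence (e.g. the `lv1 = lv0` bindings in the expected
+# modules below); removing the now-trivial bindings is left to later
+# canonicalization / dead-code-elimination passes.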
+"""Test eliminate common subexpr pass""" +import tvm +import tvm.testing +from tvm.relax.transform import EliminateCommonSubexpr +from tvm.script.parser import ir as I, relax as R, tir as T + +import numpy as np + + +def verify(input, expected): + tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected) + + +def test_simple(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv0 = R.add(x, y) + lv1 = R.add(x, y) + gv = R.multiply(lv0, lv1) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv0 = R.add(x, y) + # can combine with canonicalizing bindings + # and getting rid of unused bindings to eliminate this line too + lv1 = lv0 + gv = R.multiply(lv0, lv1) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_constants(): + @I.ir_module + class Before: + @R.function + def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")): + with R.dataflow(): + # we are not going to bind the constant 1 to a var + lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) + # we expect to bind the repeated large constants + lv1 = R.add( + R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + ) + gv = (lv0, lv1) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")): + with R.dataflow(): + lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) + lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))) + lv2 = R.add(lv1, lv1) + gv = (lv0, lv2) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_repeated_inner_tuples(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x))) + tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x))) + gv = tup[0][0][1] + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + t1 = (x, x) + t2 = (x, t1) + t3 = (t1, t2) + t4 = (t3, t3, t2) + gv = t4[0][0][1] + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_inner_function(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + # we are going to do CSE inside the local function + @R.function + def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + # not in dataflow: should not be touched + z = R.add(R.add(y, y), R.add(y, y)) + with R.dataflow(): + # writing this out in ANF to illustrate why CSE behaves as it does + # result of ANF transforming R.add(R.add(y, y), R.add(y, y)) + lv0 = R.add(y, y) + lv1 = R.add(y, y) + lv2 = R.add(lv0, lv1) + gv = lv2 + R.output(gv) + return R.add(z, gv) + + # also making the ANF explicit to better illustrate the result of CSE + # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x))) + lv0 = bar(x) + lv1 = bar(x) + lv2 = R.add(lv0, lv1) + lv3 = bar(x) + lv4 = bar(x) + lv5 = R.add(lv3, lv4) + lv6 = R.add(lv2, lv5) + gv = lv6 + R.output(gv) + return gv + + @I.ir_module + class 
Expected: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + + @R.function + def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + z = R.add(R.add(y, y), R.add(y, y)) + with R.dataflow(): + lv0 = R.add(y, y) + lv1 = lv0 + lv2 = R.add(lv0, lv1) + gv = lv2 + R.output(gv) + return R.add(z, gv) + + # can further clean this up + # using canonicalize bindings, eliminate unused bindings, and CSE again + lv0 = bar(x) + lv1 = lv0 + lv2 = R.add(lv0, lv1) + lv3 = lv0 + lv4 = lv0 + lv5 = R.add(lv3, lv4) + lv6 = R.add(lv2, lv5) + gv = lv6 + R.output(gv) + return gv + + verify(Before, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py new file mode 100644 index 000000000000..9c6e0e0567fe --- /dev/null +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -0,0 +1,452 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
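+# Tests for the Relax DeadCodeElimination transform. The pass is applied as in
+# verify() below, e.g.:
+#
+#     new_mod = DeadCodeElimination()(mod)
+#     new_mod = DeadCodeElimination(entry_functions=["foo"])(mod)
+#
+# It is expected to drop unused bindings and dataflow blocks as well as unused
+# functions, except functions carrying a "global_symbol" attribute (external
+# linkage) or listed as entry functions.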
+ +import tvm +import tvm.testing +from tvm.relax.transform import DeadCodeElimination +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + tvm.ir.assert_structural_equal(DeadCodeElimination()(input), expected) + + +def test_simple(): + @tvm.script.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + gv, + gv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv21: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv2, axes=[0, 3, 1, 2] + ) + gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + gv, + gv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_2block(): + @tvm.script.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + # block 0 + with R.dataflow(): + gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + gv, + gv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv21: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv2, axes=[0, 3, 1, 2] + ) + gv22: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(gv21, bias) + R.output(gv2) + gv3 = R.astype(gv2, dtype="float16") + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + with R.dataflow(): + gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + gv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + gv, + gv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + 
out_dtype="float32", + ) + R.output(gv2) + gv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.astype(gv2, dtype="float16") + return gv3 + + verify(Input, Expected) + + +def check_if_func_exists(mod, func_name): + gvs = [gv.name_hint for gv in mod.get_global_vars()] + return func_name in gvs + + +def test_unused_relax_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + new_mod = DeadCodeElimination()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_custom_entry_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def foo( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + + # Test entry function other than "main". + new_mod = DeadCodeElimination(entry_functions=["foo"])(mod) + assert check_if_func_exists(new_mod, "foo") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_symbolic_shape(): + # Test with relax function w/ symbolic shape. 
+ @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): + m, k = T.int64(), T.int64() + gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + + new_mod = DeadCodeElimination()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_prim_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = InputModule.relax_add(x, w) + return gv0 + + mod = InputModule + assert mod + new_mod = DeadCodeElimination()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "relax_add") + # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage. + assert check_if_func_exists(new_mod, "unused_func") + + +def test_multiple_unused_funcs(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func1( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.add(x, w) + return gv0 + + mod = InputModule + assert mod + + new_mod = DeadCodeElimination()(mod) + assert check_if_func_exists(new_mod, "main") + # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage. + assert check_if_func_exists(new_mod, "unused_func1") + assert not check_if_func_exists(new_mod, "unused_func2") + + +def test_unused_dfb(): + # test if an unused dataflow block can be removed. 
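+    # The second dataflow block in `Input` only computes lv4, whose result is never
+    # used afterwards, so the whole block is expected to be dropped in `Expected`.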
+ @tvm.script.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + # block 0 + with R.dataflow(): + lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x, axes=[0, 2, 3, 1] + ) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + ) + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(lv2) + gv3 = R.astype(lv2, dtype="float16") + # dead block + with R.dataflow(): + lv4: R.Tensor((2, 4, 26, 26), dtype="float16") = R.permute_dims( + gv3, axes=[0, 3, 1, 2] + ) + R.output(lv4) + return gv3 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + # block 0 + with R.dataflow(): + lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x, axes=[0, 2, 3, 1] + ) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + ) + R.output(lv2) + gv3 = R.astype(lv2, dtype="float16") + return gv3 + + verify(Input, Expected) + + +def test_unused_dfb2(): + # test if an unused dataflow block can be removed. + @tvm.script.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + # dead block + with R.dataflow(): + lv0: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x, axes=[0, 2, 3, 1] + ) + R.output(lv0) + + gv_x = R.astype(x, dtype="float16") + gv_w = R.astype(x, dtype="float16") + + with R.dataflow(): + lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims( + gv_x, axes=[0, 2, 3, 1] + ) + lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims( + gv_w, axes=[0, 2, 3, 1] + ) + lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( + lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + ) + # dead instruction -> usee lv1 also dead. 
+ lv4: R.Tensor((2, 3, 28, 28), dtype="float32") = R.permute_dims( + lv0, axes=[0, 3, 1, 2] + ) + R.output(lv3) + return lv3 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float16"): + gv_x = R.astype(x, dtype="float16") + gv_w = R.astype(x, dtype="float16") + + with R.dataflow(): + lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims( + gv_x, axes=[0, 2, 3, 1] + ) + lv2: R.Tensor((4, 3, 3, 3), dtype="float16") = R.permute_dims( + gv_w, axes=[0, 2, 3, 1] + ) + lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( + lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + ) + R.output(lv3) + return lv3 + + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_decompose_composite_ops.py b/tests/python/relax/test_transform_decompose_composite_ops.py new file mode 100644 index 000000000000..08483600a3ed --- /dev/null +++ b/tests/python/relax/test_transform_decompose_composite_ops.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
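+
+# A minimal NumPy sketch (not used by the tests; the helper name below is ad
+# hoc) of the inference-mode batch_norm expansion that the `expected` functions
+# in this file spell out in Relax:
+#     y = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
+# with the 1-D statistics broadcast along the channel axis via expand_dims.
+def _batch_norm_inference_reference(x, gamma, beta, moving_mean, moving_var, eps=1e-5, axis=1):
+    import numpy as np  # local import keeps the sketch self-contained
+
+    shape = [1] * x.ndim
+    shape[axis] = -1  # e.g. (64,) -> (1, 64, 1, 1), mirroring R.expand_dims
+    mean = np.reshape(np.asarray(moving_mean), shape)
+    var = np.reshape(np.asarray(moving_var), shape)
+    scale = np.reshape(np.asarray(gamma), shape)
+    shift = np.reshape(np.asarray(beta), shape)
+    return (x - mean) / np.sqrt(var + eps) * scale + shift
+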
+ +from typing import Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.relax import Function +from tvm.script import relax as R, tir as T + + +def _check(before: Union[Function, IRModule], expected: Union[Function, IRModule]): + if isinstance(before, Function): + before = IRModule({"main": before}) + if isinstance(expected, Function): + expected = IRModule({"main": expected}) + after = relax.transform.DecomposeCompositeOps()(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_batch_norm_simple(): + @R.function + def before( + x: R.Tensor((1, 64, 112, 112), "float32"), + gamma: R.Tensor((64,), "float32"), + beta: R.Tensor((64,), "float32"), + moving_mean: R.Tensor((64,), "float32"), + moving_var: R.Tensor((64,), "float32"), + ): + with R.dataflow(): + bn = R.nn.batch_norm( + x, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1e-5, + center=True, + scale=True, + ) + gv = bn[0] + R.output(gv) + return gv + + @R.function + def expected( + x: R.Tensor((1, 64, 112, 112), "float32"), + gamma: R.Tensor((64,), "float32"), + beta: R.Tensor((64,), "float32"), + moving_mean: R.Tensor((64,), "float32"), + moving_var: R.Tensor((64,), "float32"), + ): + with R.dataflow(): + mean = R.expand_dims(moving_mean, axis=[0, 2, 3]) + out = x - mean + var = R.expand_dims(moving_var, axis=[0, 2, 3]) + var_eps = var + R.const(1e-05, "float32") + sqrt_var = R.sqrt(var_eps) + div = R.divide(out, sqrt_var) + new_gamma = R.expand_dims(gamma, axis=[0, 2, 3]) + out = div * new_gamma + new_beta = R.expand_dims(beta, axis=[0, 2, 3]) + out = out + new_beta + R.output(out) + return out + + _check(before, expected) + + +def test_batch_norm_complex(): + @R.function + def before( + x: R.Tensor((1, 64, 112, 112), "float32"), + gamma: R.Tensor((64,), "float32"), + beta: R.Tensor((64,), "float32"), + moving_mean: R.Tensor((64,), "float32"), + moving_var: R.Tensor((64,), "float32"), + ): + with R.dataflow(): + bn = R.nn.batch_norm( + x, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1e-5, + center=True, + scale=True, + ) + gv0 = bn[0] + gv1 = bn[1] + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def expected( + x: R.Tensor((1, 64, 112, 112), "float32"), + gamma: R.Tensor((64,), "float32"), + beta: R.Tensor((64,), "float32"), + moving_mean: R.Tensor((64,), "float32"), + moving_var: R.Tensor((64,), "float32"), + ): + with R.dataflow(): + # bn[1] is used, so we need to keep the original batch_norm + # NOTE: It's a rare case, so that we don't optimize it for now + bn = R.nn.batch_norm( + x, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1e-5, + center=True, + scale=True, + ) + mean = R.expand_dims(moving_mean, axis=[0, 2, 3]) + out = x - mean + var = R.expand_dims(moving_var, axis=[0, 2, 3]) + var_eps = var + R.const(1e-05, "float32") + sqrt_var = R.sqrt(var_eps) + div = R.divide(out, sqrt_var) + new_gamma = R.expand_dims(gamma, axis=[0, 2, 3]) + out = div * new_gamma + new_beta = R.expand_dims(beta, axis=[0, 2, 3]) + out = out + new_beta + gv1 = bn[1] + R.output(out, gv1) + return out, gv1 + + _check(before, expected) + + +def test_op_tensor_to_shape(): + @R.function + def before(t: R.Tensor(ndim=1, dtype="int64")): + gv: R.Shape(ndim=3) = R.tensor_to_shape(t) + return gv + + @R.function + def expected(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3): + x = T.int64() + x_1 = T.int64() + x_2 = T.int64() + gv: R.Shape(ndim=3) = R.call_packed( + "vm.builtin.tensor_to_shape", t, 
sinfo_args=(R.Shape(ndim=3),) + ) + y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2])) + gv_1: R.Shape([x, x_1, x_2]) = R.shape([x, x_1, x_2]) + return gv_1 + + _check(before, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py new file mode 100644 index 000000000000..b8ad5c4487d3 --- /dev/null +++ b/tests/python/relax/test_transform_fold_constant.py @@ -0,0 +1,454 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax +import numpy as np + +import tvm.script +from tvm.script import ir as I, tir as T, relax as R + + +def gen_mod(mod, name, binding): + """Select relax function with name, rename to main and and bind constant. + + Parameters + ---------- + mod: IRModule + The input module + + name: str + The name of relax function to preserve and rename to main + + binding: Dict[str, array] + The const parameter bindings + """ + funcs = {} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + + for k, v in mod.functions.items(): + if isinstance(v, tvm.relax.Function): + if k.name_hint == name: + # rename to main + gv = tvm.ir.GlobalVar("main") + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( + "global_symbol", "main" + ) + else: + funcs[k] = v + mod = tvm.IRModule(funcs) + return relax.transform.BindParams("main", binding)(mod) + + +def test_one_fold_addone(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + cls = Module + lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + lv0 = c1 + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_one_fold_transpose(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @T.prim_func + def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")) -> None: + for i, j in T.grid(3, 2): + with T.block("transpose"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] + + @R.function + def before(c0: R.Tensor((2, 3), "float32")): + cls = Module + 
lv0 = relax.call_tir(cls.func, (c0,), R.Tensor((3, 2), dtype="float32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((3, 2), "float32")): + lv0 = c1 + return c1 + + c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3) + c1_np = c0_np.T + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_two_hop_addone(): + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")) -> None: + for i, j in T.grid(2, 2): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def before(c0: R.Tensor((2, 2), "float32")): + cls = Module + lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((2, 2), dtype="float32")) + lv1 = relax.call_tir(cls.addone, (lv0,), R.Tensor((2, 2), dtype="float32")) + return lv1 + + @R.function + def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")): + lv0 = c1 + lv1 = c2 + return c2 + + c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2) + c1_np = c0_np + 1 + c2_np = c1_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_dataflow_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: + for i, j in T.grid(16, 16): + with T.block("identity"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + cls = Module + with R.dataflow(): + gv0 = relax.call_tir(cls.identity, (c0,), R.Tensor((16, 16), dtype="float32")) + R.output(gv0) + return gv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_mixed_case(): + @tvm.script.ir_module + class Module: + # TIR function can handle different cases. 
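+        # `addone` matches its buffers with symbolic shapes (n, m), so the same
+        # PrimFunc serves both the dynamic-shape call (which must stay) and the
+        # constant-shape calls (which fold) in `before` below.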
+ @T.prim_func + def addone(a: T.handle, b: T.handle) -> None: + n = T.int32() + m = T.int32() + A = T.match_buffer(a, (n, m)) + B = T.match_buffer(b, (n, m)) + for i, j in T.grid(n, m): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @T.prim_func + def sub( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): + n, m = T.int64(), T.int64() + cls = Module + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32")) + # this line can be folded + lv1 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32")) + # this line can be folded because all inputs are const + lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) + return lv3 + + @R.function + def expected( + c0: R.Tensor((16, 16), "float32"), + c1: R.Tensor((16, 16), "float32"), + c2: R.Tensor((16, 16), "float32"), + x: R.Tensor("float32", ndim=2), + ) -> R.Tensor: + n, m = T.int64(), T.int64() + cls = Module + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32")) + # this line can be folded + lv1 = c1 + # this line can be folded because all inputs are const + lv2 = c2 + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32")) + return lv3 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + c2_np = c0_np - c1_np + + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2": c2_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_int32_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def before(c0: R.Tensor((16, 16), "int32")): + cls = Module + lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="int32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_single_relax_op(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + with R.dataflow(): + gv = R.add(c0, c0) + R.output(gv) + return gv + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 
16) + c1_np = c0_np + c0_np + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_multiple_relax_ops(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @R.function + def before(c0: R.Tensor((16, 16), "float32"), c1: R.Tensor((16, 16), "float32")): + with R.dataflow(): + lv0 = R.add(c0, c1) + lv1 = R.multiply(c0, lv0) + gv = R.subtract(lv1, c1) + R.output(gv) + return gv + + @R.function + def expected(c4: R.Tensor((16, 16), "float32")): + return c4 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c2_np = c0_np + c1_np + c3_np = c0_np * c2_np + c4_np = c3_np - c1_np + before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np}) + expected = gen_mod(Module, "expected", {"c4": c4_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_do_not_fold_ops_outside_dataflow(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + gv = R.add(c0, c0) + return gv + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + before = gen_mod(Module, "before", {"c0": c0_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, before) + + +def test_fold_multiple_relax_ops_with_data_dependent_reshape(): + @tvm.script.ir_module + class Module: + @R.function + def before( + data: R.Tensor((256,), "float32"), + c0: R.Tensor((2,), "int64"), + c1: R.Tensor((2,), "int64"), + ): + with R.dataflow(): + lv0 = R.add(c0, c0) + target_shape = R.multiply(lv0, c1) + lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape) + gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv2) + R.output(gv) + return gv + + @R.function + def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, R.shape([16, 16])) + R.output(gv) + return gv + + c0_np = [8, 8] + c1_np = [1, 1] + before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np}) + assert relax.analysis.well_formed(before) + + c2_np = np.multiply(np.add(c0_np, c0_np), c1_np) + expected = gen_mod(Module, "expected", {"c2": c2_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_unsupported_fold_ops_legalized_to_multiple_calls(): + @tvm.script.ir_module + class Module: + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + with R.dataflow(): + gv = R.nn.relu(c0) + R.output(gv) + return gv + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + before = gen_mod(Module, "before", {"c0": c0_np}) + + from tvm.relax.transform.legalize_ops.common import register_legalize + + def customized_legalize_relu(bb: relax.BlockBuilder, call: relax.Call): + from tvm import topi # pylint: disable=import-outside-toplevel + + x = bb.emit_te(topi.nn.relu, *call.args) + return bb.call_te(topi.identity, x) + + # register custom legalization for relu that emits multiple bindings for testing + relu_legalize = tvm.ir.Op.get("relax.nn.relu").get_attr("FLegalize") + tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize") + register_legalize("relax.nn.relu", customized_legalize_relu) + + after = 
relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, before) + + # revert to correct legalization of relu + tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize") + register_legalize("relax.nn.relu", relu_legalize) + + +def test_fold_shape_computation(): + @I.ir_module + class Module: + @R.function + def before( + data: R.Tensor((5, 4, 3, 2), dtype="float32"), + indices: R.Tensor((1,), dtype="int64"), + ) -> R.Tensor((1, 1), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(R.shape([5, 4, 3, 2])) + lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices, axis=0) + lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0]) + gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0) + R.output(gv) + return gv + + @R.function + def expected( + data: R.Tensor((5, 4, 3, 2), dtype="float32"), new_shape: R.Tensor((1, 1), "int64") + ) -> R.Tensor((1, 1), dtype="int64"): + return new_shape + + before = gen_mod(Module, "before", {"indices": tvm.nd.array(np.array([0]).astype("int64"))}) + after = relax.transform.FoldConstant()(before) + np_take = np.take([5, 4, 3, 2], [0], axis=0) + np_expand = np.expand_dims(np_take, axis=[0]) + np_concat = np.concatenate([np_expand], axis=0) + expected = gen_mod(Module, "expected", {"new_shape": tvm.nd.array(np_concat)}) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py new file mode 100644 index 000000000000..72f4e29a1690 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -0,0 +1,1295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
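+# The `_check` helper below runs AnnotateTIROpPattern before FuseOps so that
+# every PrimFunc carries the "op_pattern" attribute the fusion pass consults;
+# each test then compares the fused module against a hand-built expected one.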
+ +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import ir as I, relax as R, tir as T + + +def _check(mod_actual, mod_expected): + mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual) + mod_actual = relax.transform.FuseOps()(mod_actual) + mod_expected = relax.transform.AnnotateTIROpPattern()(mod_expected) + tvm.ir.assert_structural_equal(mod_actual, mod_expected) + + +def test_fuse_simple(): + """Simple testcase.""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + """Test fusion case of conv2d""" + + def before(dtype): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1) + # this is the next dominator. 
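+                # The two adds defined next both consume lv1 and re-converge at
+                # the second add, which post-dominates lv1; conv2d and both adds
+                # therefore form one group (fused_conv2d_add1_add2 in expected()).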
+ lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1) + lv3 = bb.emit_te(topi.add, lv1, lv2) + # second path + lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1) + gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + _check(before("float16"), expected("float16")) + _check(before("int8"), expected("int8")) + + +def test_concatenate(): + """Test fusion case involving concat op and Tuple node""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = 
bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) + lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv1, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_upsampling_concatenate_add = bb.get().get_global_var( + "fused_upsampling_concatenate_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output( + relax.Call( + fused_upsampling_concatenate_add, (lv0, x, relax.const(1, "float32")) + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_root(): + """Test fusion case where Tuple node is the root in its group""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + gv = bb.emit_output((lv1, x)) + bb.emit_func_output(gv) + + return bb.get() + + # The fusion is supposed to make no change. + _check(before(), before()) + + +def test_fuse_tuple_get_elemwise(): + def before(dim: int): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, axis=1) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit_te(topi.sigmoid, lv2) + lv4 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv5 = bb.emit_te(topi.tanh, lv4) + lv6 = bb.emit(relax.TupleGetItem(lv1, 2)) + lv7 = bb.emit_te(topi.exp, lv6) + lv8 = bb.emit_te(topi.multiply, lv5, lv7) + gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32")) + with bb.function( + "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + lv2 = bb.emit_te(topi.sigmoid, lv1) + lv3 = bb.emit(relax.TupleGetItem(lv0, 1)) + lv4 = bb.emit_te(topi.tanh, lv3) + lv5 = bb.emit(relax.TupleGetItem(lv0, 2)) + lv6 = bb.emit_te(topi.exp, lv5) + lv7 = bb.emit_te(topi.multiply, lv4, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var( + "fused_split_sigmoid_tanh_exp_multiply_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + gv = bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,))) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_get_root(): + def before(dim: int): + bb = 
relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split = bb.get().get_global_var("fused_split") + + # Main function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_split, (x,))) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w)) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_intermediate(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32")) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + with bb.function( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", + [x, p0, p1, p2, p3, p4], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, p0) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, p1) + lv4 = bb.emit_te(topi.add, lv3, p2) + lv5 = bb.emit_te(topi.add, lv0, p3) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, p4)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1" + ) + + # Main func + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call( + fused_func, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def 
test_tuple_consecutive(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32")) + lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32")) + lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1) + lv16 = bb.emit_te( + topi.nn.pool2d, + lv15, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32")) + lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32")) + gv = bb.emit_output((lv17, lv18)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + p5 = relax.Var("p5", R.Tensor((), "float32")) + p6 = relax.Var("p6", R.Tensor((), "float32")) + p7 = relax.Var("p7", R.Tensor((), "float32")) + p8 = relax.Var("p8", R.Tensor((), "float32")) + p9 = relax.Var("p9", R.Tensor((), "float32")) + p10 = relax.Var("p10", R.Tensor((), "float32")) + p11 = relax.Var("p11", R.Tensor((), "float32")) + with bb.function( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", + [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.add, x, p1) + lv2 = bb.emit_te(topi.add, x, p2) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, p3) + lv5 = bb.emit_te(topi.add, x, p4) + lv6 = bb.emit_te(topi.add, x, p5) + lv7 = bb.emit_te(topi.add, x, p6) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, p7) + lv10 = bb.emit_te(topi.add, x, p8) + lv11 = bb.emit_te(topi.add, x, p9) + lv12 = bb.emit_te(topi.add, x, p10) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, p11) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, lv14), axis=1)) + bb.emit_func_output(gv) + + # Grouped function 2 + concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + concat, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + 
padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_func1 = mod.get_global_var( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1" + ) + fused_func2 = mod.get_global_var("fused_pool2d_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit( + relax.Call( + fused_func1, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, "float32")))) + lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32")) + gv = bb.emit_output((lv1, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_inception_like(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, dilation=1) + lv1 = bb.emit_te(topi.nn.relu, lv0) + lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1) + lv3 = bb.emit_te(topi.nn.relu, lv2) + lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, padding=1, dilation=1) + lv6 = bb.emit_te(topi.nn.relu, lv5) + lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, padding=1, dilation=1) + lv8 = bb.emit_te(topi.nn.relu, lv7) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32")) + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu") + fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + 
w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0))) + lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1))) + lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1) + lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2))) + lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3))) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_fuse_parallel_injective(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32")) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0]) + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + p0 = relax.Var("p0", R.Tensor((), "int32")) + with bb.function( + "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], primfunc_name_hint="transpose1") + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift") + + # Main function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, "int32")))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_softmax(): + """Test if softmax can be fused with following ops.""" + + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_softmax_cast") + + # Main function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x,))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_multiple_relax_functions(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, 
x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add1_exp1_squeeze1 = bb.get().get_global_var("fused_add1_exp1_squeeze1") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_skip_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) + R.output(y) + return y + + # FuseOps should does no change to it. + _check(Module, Module) + + +def test_edge_with_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + cls = Module + with R.dataflow(): + a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), "float32")) + b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), "float32")) + c = R.call_dps_packed("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) + R.output(b, c) + return R.tuple(b, c) + + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + # FuseOps should does no change to it. 
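+    # (call_dps_packed invokes an external packed function rather than a
+    # PrimFunc in the module, so it is not a candidate for fusion here.)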
+ _check(Module, Module) + + +def test_layer_norm_silu(): + # fmt: off + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): + cls = Module + with R.dataflow(): + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) + R.output(gv1) + return gv1 + + @T.prim_func + def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), gamma: T.Buffer((T.int64(64), T.int64(64)), "float32"), beta: T.Buffer((T.int64(64), T.int64(64)), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + A[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + A[ax0, ax1, k2, k3] * A[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (A[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * gamma[ax2, ax3] + beta[ax2, ax3] + + @T.prim_func + def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("relu"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3]) + T.writes(B[v_i0, v_i1, v_i2, v_i3]) + B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @I.ir_module + class Expected: + @T.prim_func + def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), gamma: T.Buffer((T.int64(64), T.int64(64)), "float32"), beta: T.Buffer((T.int64(64), T.int64(64)), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 4}) + # with T.block("root"): + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(64), T.int64(64))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(64), T.int64(64))) + 
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + A[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + A[ax0, ax1, k2, k3] * A[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (A[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.050000000000000003) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003)) + T.float32(1.0000000000000001e-05)) * gamma[ax2, ax3] + beta[ax2, ax3] + + @T.prim_func + def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 0}) + # with T.block("root"): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("relu"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3]) + T.writes(B[v_i0, v_i1, v_i2, v_i3]) + B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @R.function + def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Expected + with R.dataflow(): + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) + R.output(gv) + return gv + + @R.function + def main(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + cls = Expected + with R.dataflow(): + gv: R.Tensor((1, 512, 64, 64), dtype="float32") = cls.fused_layer_norm_relu(x, mean, var) + R.output(gv) + return gv + # fmt: on + + _check(Module, Expected) + + +def test_multiple_paths(): + # fmt: off + @I.ir_module + class Module: + @R.function + def main( + inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), + inp_1: R.Tensor((2, 1280), dtype="float32"), + w1: R.Tensor((320, 320, 3, 3), dtype="float32"), + b1: R.Tensor((320,), "float32"), + w2: R.Tensor((320, 1280), "float32"), + b2: R.Tensor((320,), "float32"), + ): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv27: R.Tensor((2, 320, 64, 64), dtype="float32") = R.nn.conv2d(inp_0, w1, 
strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32") + lv28: R.Tensor((1, 320, 1, 1), dtype="float32") = R.reshape(b1, R.shape([1, 320, 1, 1])) ## + lv29: R.Tensor((2, 320, 64, 64), dtype="float32") = R.add(lv27, lv28) + lv31: R.Tensor((1280, 320), dtype="float32") = R.permute_dims(w2, axes=None) ## + lv32: R.Tensor((2, 320), dtype="float32") = R.matmul(inp_1, lv31, out_dtype="float32") + lv33: R.Tensor((2, 320), dtype="float32") = R.add(lv32, b2) + lv35: R.Tensor((2, 320, 1, 1), dtype="float32") = R.reshape(lv33, R.shape([2, 320, 1, 1])) + lv36: R.Tensor((2, 320, 64, 64), dtype="float32") = R.add(lv29, lv35) + gv = lv36 + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + + @T.prim_func + def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(320)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1] + + @T.prim_func + def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)] + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320), T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) + pad_temp = T.alloc_buffer((T.int64(2), T.int64(320), T.int64(66), T.int64(66))) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320), T.int64(66), T.int64(66)): + with T.block("pad_temp"): + v_i0, v_i1, 
v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)]) + T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3]) + pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(65) and T.int64(1) <= v_i3 and v_i3 < T.int64(65), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0)) + for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64), T.int64(320), T.int64(3), T.int64(3)): + with T.block("conv2d_nchw"): + v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx]) + T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], rxplaceholder_1[v_ff, v_rc, v_ry, v_rx]) + T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx]) + with T.init(): + conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0) + conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * rxplaceholder_1[v_ff, v_rc, v_ry, v_rx] + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) + for i0, i1, k in T.grid(T.int64(2), T.int64(320), T.int64(1280)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) + T.writes(matmul[v_i0, v_i1]) + with T.init(): + matmul[v_i0, v_i1] = T.float32(0) + matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"op_pattern": 2, "tir.noalias": True}) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(320), T.int64(1), T.int64(1)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % T.int64(320)] + + @T.prim_func + def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"op_pattern": 2, "tir.noalias": True}) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(1), T.int64(1)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // T.int64(320) + v_ax0) % T.int64(2), (v_ax1 + v_ax2 + v_ax3) % T.int64(320)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // T.int64(320) + v_ax0) % T.int64(2), (v_ax1 + v_ax2 + v_ax3) % T.int64(320)] + + @T.prim_func + def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"), T_transpose: T.Buffer((T.int64(1280), T.int64(320)), "float32")): + T.func_attr({"op_pattern": 2, "tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(1280), T.int64(320)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] 
= rxplaceholder[v_ax1, v_ax0] + + @R.function + def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Expected + with R.dataflow(): + lv27 = R.call_tir(cls.conv2d, (inp_0, w1), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) + lv29 = R.call_tir(cls.add, (lv27, lv28), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) + gv = R.call_tir(cls.add2, (lv29, lv35), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) + R.output(gv) + return gv + + @R.function + def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 320), dtype="float32"): + cls = Expected + R.func_attr({"Primitive": 1}) + with R.dataflow(): + lv32 = R.call_tir(cls.matmul, (inp_1, lv31), out_sinfo=R.Tensor((2, 320), dtype="float32")) + gv = R.call_tir(cls.add1, (lv32, b2), out_sinfo=R.Tensor((2, 320), dtype="float32")) + R.output(gv) + return gv + + @R.function + def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1: R.Tensor((2, 1280), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), b1: R.Tensor((320,), dtype="float32"), w2: R.Tensor((320, 1280), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"): + R.func_attr({"num_input": 2}) + cls = Expected + with R.dataflow(): + lv28 = R.call_tir(cls.reshape, (b1,), out_sinfo=R.Tensor((1, 320, 1, 1), dtype="float32")) + lv31 = R.call_tir(cls.transpose, (w2,), out_sinfo=R.Tensor((1280, 320), dtype="float32")) + lv: R.Tensor((2, 320), dtype="float32") = cls.fused_matmul_add1(inp_1, lv31, b2) + lv35 = R.call_tir(cls.reshape1, (lv,), out_sinfo=R.Tensor((2, 320, 1, 1), dtype="float32")) + lv1: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35) + gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1 + R.output(gv) + return gv + # fmt: on + + mod = relax.transform.LegalizeOps()(Module) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FuseOps()(mod) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dead_group(): + + # fmt: off + + @I.ir_module + class Module: + @R.function + def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,), dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10, 128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(linear1_weight, axes=None) + lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32") + lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1, linear1_bias) + lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2) + lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(linear2_weight, axes=None) + lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(inp_1, lv4, out_dtype="float32") + lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, linear2_bias) + gv: R.Tensor((1, 10), dtype="float32") = lv6 + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def add(rxplaceholder: 
T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(128)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1] + + @T.prim_func + def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1] + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) + # with T.block("root"): + for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) + T.writes(matmul_1[v_i0, v_i1]) + with T.init(): + matmul_1[v_i0, v_i1] = T.float32(0) + matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] + + @T.prim_func + def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) + # with T.block("root"): + for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(128)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) + T.writes(matmul[v_i0, v_i1]) + with T.init(): + matmul[v_i0, v_i1] = T.float32(0) + matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] + + @T.prim_func + def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")): + T.func_attr({"op_pattern": 0, "tir.noalias": True}) + # with T.block("root"): + for i0, i1 in T.grid(T.int64(1), T.int64(128)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0)) + + @T.prim_func + def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")): + T.func_attr({"op_pattern": 2, "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(784), T.int64(128)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] + + @T.prim_func + def 
transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")): + T.func_attr({"op_pattern": 2, "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(128), T.int64(10)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] + + @R.function + def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Expected + with R.dataflow(): + lv5 = R.call_tir(cls.matmul1, (inp_1, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32")) + gv = R.call_tir(cls.add1, (lv5, linear2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32")) + R.output(gv) + return gv + + @R.function + def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,), dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10, 128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32")) + lv4 = R.call_tir(cls.transpose1, (linear2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32")) + lv_1: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias) + gv: R.Tensor((1, 10), dtype="float32") = lv_1 + R.output(gv) + return gv + + # fmt: on + + mod = relax.transform.LegalizeOps()(Module) + _check(mod, Expected) + + +def test_symbolic_shape_aware_fuse(): + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor(["n", "m"], "float32")): + with R.dataflow(): + lv0 = R.emit_te(topi.add, x, R.const(1, "float32")) + lv1 = R.emit_te(topi.exp, lv0) + gv = R.emit_te(topi.squeeze, lv1) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def fused_add_exp_squeeze( + x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32") + ) -> R.Tensor(["n", "m"], dtype="float32"): + R.func_attr({"Primitive": 1}) + with R.dataflow(): + lv0 = R.emit_te(topi.add, x, p0) + lv1 = R.emit_te(topi.exp, lv0) + gv = R.emit_te(topi.squeeze, lv1) + R.output(gv) + return gv + + @R.function + def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"): + cls = Expected + with R.dataflow(): + gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32")) + R.output(gv) + return gv + + _check(Before, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py new file mode 100644 index 000000000000..2f3e2d479ff4 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -0,0 +1,675 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest + +import tvm +from tvm import relax +from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern, wildcard +from tvm.relax.transform import PatternCheckContext +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +@tvm.script.ir_module +class Conv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1))) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dReLU_composite_annotated: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + cls = Conv2dReLU_composite_annotated + with R.dataflow(): + gv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_dnnl( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr( + {"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_dnnl"} + ) + + @R.function + def gv1( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight12: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data2, + weight12, + padding=[1, 1, 1, 1], + ) + gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(gv2) + return gv2 + + gv11: R.Tensor((1, 64, 56, 56), dtype="float32") = gv1(data1, weight11) + return gv11 + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1))) + conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2) + + return conv2 + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Conv2dReLUx2Partitioned + with R.dataflow(): + lv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> 
R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, weight21, padding=[0, 0, 0, 0] + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned_only_conv2d: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Conv2dReLUx2Partitioned_only_conv2d + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight1 + ) + conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d1( + conv1, weight2 + ) + conv2d: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(conv2d) + return conv2d + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d1( + conv11: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv11, weight21, padding=[0, 0, 0, 0] + ) + R.output(gv1) + return gv1 + + +@tvm.script.ir_module +class Conv2dConv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2d) + + return conv2d + + +@tvm.script.ir_module +class Conv2dConv2dReLUPartitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Conv2dConv2dReLUPartitioned + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight1 + ) + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: 
R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, weight21, padding=[0, 0, 0, 0] + ) + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class BranchTupleOutput: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight) + relu1 = R.nn.relu(conv1) + gelu1 = R.nn.gelu(relu1) + gelu2 = R.nn.gelu(conv1) + out = relax.op.add(gelu1, gelu2) + R.output(out) + + return out + + +@tvm.script.ir_module +class BranchTupleOutputPartitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + cls = BranchTupleOutputPartitioned + lv: R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), + R.Tensor((1, 64, 54, 54), dtype="float32"), + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight) + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[1] # conv1 + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[0] # relu(conv1) + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv2) + gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv1) + out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1, gelu2) + R.output(out) + return out + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32") + ): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(data1, weight1) + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(gv) + R.output(gv, gv1) + return (gv1, gv) + + +@tvm.script.ir_module +class Branch: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight) + relu1 = R.nn.relu(conv1) + gelu1 = R.nn.gelu(conv1) + + out = relax.op.add(relu1, gelu1) + R.output(out) + + return out + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight1: R.Tensor((16, 3, 3, 16), "float16"), + weight2: R.Tensor((16, 3, 3, 16), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + conv2 = relax.op.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + R.output(conv2) + + return conv2 + + +@tvm.script.ir_module +class Conv2dx2_partitioned: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), 
+ weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + cls = Conv2dx2_partitioned + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass( + data, weight1 + ) + gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_cutlass( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Codegen": "cutlass", "global_symbol": "fused_relax_nn_conv2d_cutlass"}) + + @R.function + def gv( + data_1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d( + data_1, + weight1_1, + padding=[1, 1, 1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(gv_1) + return gv_1 + + gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1) + return gv1 + + +conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None) +conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu") + + +def check(mod, patterns, expected, bind_constants=True, annotate_codegen=False): + partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(mod) + tvm.ir.assert_structural_equal(partitioned, expected) + + +def test_partition_conv2d_relu(): + check(Conv2dReLUx2, [("dnnl.conv2d_relu", conv2d_relu_pat)], Conv2dReLUx2Partitioned) + + +def test_partition_multiple_patterns(): + check( + Conv2dConv2dReLU, + [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)], + Conv2dConv2dReLUPartitioned, + ) + + +def test_partition_order(): + check( + Conv2dReLUx2, + [("dnnl.conv2d", conv2d_pat), ("dnnl.conv2d_relu", conv2d_relu_pat)], + Conv2dReLUx2Partitioned_only_conv2d, + ) + + +def test_branch_tuple_output(): + check(BranchTupleOutput, [("dnnl.conv2d_relu", conv2d_relu_pat)], BranchTupleOutputPartitioned) + + +def test_cyclic_dependency(): + conv_pat = make_fused_bias_activation_pattern("relax.nn.conv2d") + relu_pat = is_op("relax.nn.relu")(conv_pat) + add_pat = is_op("relax.add")(relu_pat, wildcard()) + + with pytest.raises(tvm.error.TVMError) as err: + relax.transform.FuseOpsByPattern( + [("compiler_A.conv2d_relu_add", add_pat)], bind_constants=True + )(Branch) + + assert "A cyclic dependency detected" in str(err.value) + + +def test_bind_params(): + weight_np = np.random.randn(64, 64, 3, 3).astype("float32") + mod = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", {"weight1": weight_np}), + relax.transform.FuseOpsByPattern( + [("dnnl.conv2d_relu", conv2d_relu_pat)], bind_constants=True + ), + ] + )(Conv2dReLU) + + assert "fused_relax_nn_conv2d_relax_nn_relu" in [var.name_hint for var in mod.functions.keys()] + + for gvar, f in mod.functions.items(): + if gvar.name_hint == "fused_relax_nn_conv2d_relax_nn_relu": + conv2d = f.body.blocks[0].bindings[0].value + assert isinstance(conv2d.args[1], relax.Constant) + + +def test_annotate_codegen(): + check( + Conv2dReLU, + [("dnnl.conv2d_relu", conv2d_relu_pat)], + Conv2dReLU_composite_annotated, + 
annotate_codegen=True, + ) + + +def test_multiple_calls_same_extern(): + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annotate_codegen=True) + + +def test_ignore_call_tir(): + @I.ir_module + class Conv2dReLUCallTIR: + @T.prim_func + def relu( + data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32") + ): + for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + with T.block("root"): + i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + relu1 = R.call_tir( + Conv2dReLUCallTIR.relu, (conv1,), R.Tensor((64, 64, 56, 56), "float32") + ) + R.output(relu1) + + return relu1 + + @I.ir_module + class Conv2dReLUCallTIR_partitioned: + @T.prim_func + def relu( + data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32") + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + with T.block("root"): + i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(data[i, j, k, l]) + T.writes(out[i, j, k, l]) + out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0)) + + @R.function + def fused_relax_nn_conv2d( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data, + weight1, + padding=(1, 1), + ) + R.output(gv) + return gv + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((64, 64, 56, 56), dtype="float32"): + cls = Conv2dReLUCallTIR_partitioned + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight1 + ) + relu1 = R.call_tir( + cls.relu, (lv,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32") + ) + R.output(relu1) + return relu1 + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dReLUCallTIR, [("cutlass.conv2d", pat)], Conv2dReLUCallTIR_partitioned) + + +def test_unused(): + @I.ir_module + class Conv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + relu = R.nn.relu(data) + R.output(conv1) + + return conv1 + + @I.ir_module + class Conv2dReLU_partitioned: + @R.function + def fused_relax_nn_conv2d( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data, weight1, padding=(1, 1) + ) + R.output(gv) + return gv + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + cls = Conv2dReLU_partitioned + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), 
dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight1 + ) + relu: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(data) + R.output(gv) + return gv + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned) + + +def test_check_pattern(): + lhs = wildcard() + rhs = wildcard() + out = is_op("relax.nn.conv2d")(lhs, rhs) + annotation_patterns = {"root": out, "lhs": lhs, "rhs": rhs} + + def pred(context: PatternCheckContext): + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + expr = context.annotated_expr["root"] + assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data" + assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1" + assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d" + return False + + check( + Conv2dReLU, [("cutlass.conv2d", out, annotation_patterns, pred)], Conv2dReLU + ) # expect no partitioning + + +def test_bind_constants(): + weight = np.random.randn(64, 64, 3, 3).astype("float32") + + @I.ir_module + class Conv2dWithConstantWeight: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, R.const(weight), padding=(1, 1)) + R.output(conv1) + return conv1 + + @I.ir_module + class Conv2dWithConstantWeight_partitioned: + @R.function + def fused_relax_nn_conv2d( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + param_0: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv = R.nn.conv2d(data, param_0, padding=(1, 1)) + R.output(gv) + return gv + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + cls = Conv2dWithConstantWeight_partitioned + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( + data, R.const(weight) + ) + R.output(gv) + return gv + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check( + Conv2dWithConstantWeight, + [("cutlass.conv2d", pat)], + Conv2dWithConstantWeight_partitioned, + bind_constants=False, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py new file mode 100644 index 000000000000..356e28d6e910 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -0,0 +1,702 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
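For context when reviewing (illustrative note, not part of the patch): the pattern-based partitioning exercised by the tests above is typically driven as sketched below. The tiny `ConvRelu` module and the `"dnnl.conv2d_relu"` label are placeholders modeled on the fixtures in test_transform_fuse_ops_by_pattern.py, and the keyword-argument names are assumed to mirror the positional arguments of its `check()` helper.

```python
# Minimal sketch, assuming the FuseOpsByPattern signature used by check() above.
import tvm
from tvm import relax
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
from tvm.script import relax as R


@tvm.script.ir_module
class ConvRelu:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            out = R.nn.relu(R.nn.conv2d(data, weight, padding=(1, 1)))
            R.output(out)
        return out


# Match conv2d followed by relu and offload each match into a composite
# function labeled "dnnl.conv2d_relu".
pattern = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu")
partitioned = relax.transform.FuseOpsByPattern(
    [("dnnl.conv2d_relu", pattern)], bind_constants=True, annotate_codegen=True
)(ConvRelu)
print(partitioned)
```

With `annotate_codegen=True`, each composite is additionally wrapped in an outer function carrying a `Codegen` attribute, which is what the `Conv2dReLU_composite_annotated` expectation above encodes.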
+ +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import ir as I, relax as R, tir as T + + +def _check(mod_before, mod_expected): + mod = relax.transform.FuseTIR()(mod_before) + tvm.ir.assert_structural_equal(mod, mod_expected) + + +def test_simple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) + bb.emit_func_output(gv) + + return bb.get().with_attrs({"foo": "bar"}) + + def expected(): + def fused_add_exp_squeeze(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) + bb.emit_func_output(gv) + return bb.get().with_attrs({"foo": "bar"}) + + _check(before(), expected()) + + +def test_conv2d_fuse(): + def before(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, 
w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + def fused_conv2d_add1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1) + add = topi.add(p, conv) + return topi.add(conv, add) + + def fused_conv2d1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1) + return topi.add(conv, p) + + bb = relax.BlockBuilder() + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype)) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + + +def test_two_subfunction(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + lv2 = bb.emit(relax.Call(func_gv, [lv])) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(fused_exp_squeeze, lv) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_same_primfunc(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + lv2 = bb.emit_te(topi.exp, lv1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_exp_squeeze(x): + exp = topi.exp(x) + exp = topi.exp(exp) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_exp_squeeze, x) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_tuple_as_param(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("fused_exp_add", [x], attrs={"Primitive": True}): + with 
bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add") + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add(x1, x2): + exp = topi.exp(x1) + return topi.add(exp, x2) + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_nested_tuple_as_param(): + tuple_struct_info = R.Tuple( + [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])] + ) + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv0_exp = bb.emit_te(topi.exp, lv0) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv2 = bb.emit_te(topi.add, lv1_0, lv1_1) + gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add_add") + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add_add(x1, x2, x3): + exp = topi.exp(x1) + add = topi.add(x2, x3) + return topi.add(exp, add) + + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit(relax.TupleGetItem(lv1, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_call_tir_in_main(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(func_gv, [x])) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(topi.add, 
lv, relax.const(1, "float32")) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_const_in_argument(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + x2 = relax.Var("x2", R.Tensor([], "float32")) + with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x1, x2) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_add_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, y): + add = topi.add(x, y) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32")) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_tuple_output(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + gv0 = bb.emit_output(bb.call_te(topi.add, x, p0)) + gv1 = bb.emit_output(bb.call_te(topi.exp, gv0)) + bb.emit_func_output(relax.Tuple([gv0, gv1])) + fused_add_exp = bb.get().get_global_var("fused_add_exp") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + return add, exp + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_immediate_tuple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + + with bb.function("fused_add", [x, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])])) + lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0)) + lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1)) + lv_y = bb.emit(relax.TupleGetItem(lv0, 1)) + gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y)) + bb.emit_func_output(gv) + fused_add = bb.get().get_global_var("fused_add") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add, [x, y])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", 
R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add")) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_return_partial_result(): + def te_argmax_idx_val(val): + from tvm import te + + def f_combine(x, y): + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + m, n = val.shape + k = te.reduce_axis((0, n), "k") + max_idx, max_val = te.compute( + (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax" + ) + return max_idx, max_val + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(te_argmax_idx_val, x) + idx = bb.emit(relax.TupleGetItem(lv, 0)) + gv = bb.emit_output(bb.call_te(topi.add, idx, offset)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_argmax_add") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("x", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x, offset])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_argmax_add(x, offset): + idx, value = te_argmax_idx_val(x) + idx = topi.add(idx, offset) + return idx + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_multiple_relax_functions(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add1_exp1_squeeze1 = bb.get().get_global_var("fused_add1_exp1_squeeze1") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + @I.ir_module + class 
Expected: + @R.function + def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"): + with R.dataflow(): + gv2 = R.call_tir( + Expected.fused_add_exp_squeeze, + (x, R.const(1, "float32")), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv2) + return gv2 + + @R.function + def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), dtype="float32"): + with R.dataflow(): + gv3 = R.call_tir( + Expected.fused_add1_exp1_squeeze1, + (x, R.const(1, "float32")), + out_sinfo=R.Tensor((20, 10), dtype="float32"), + ) + R.output(gv3) + return gv3 + + @T.prim_func + def fused_add1_exp1_squeeze1( + x: T.Buffer((T.int64(20), T.int64(10)), "float32"), + p0: T.Buffer((), "float32"), + T_squeeze: T.Buffer((T.int64(20), T.int64(10)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + T_add = T.alloc_buffer((T.int64(20), T.int64(10))) + compute = T.alloc_buffer((T.int64(20), T.int64(10))) + for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(20), T.int64(10)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_squeeze[v_ax0, v_ax1]) + T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1] + + @T.prim_func + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + T_squeeze: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + T_add = T.alloc_buffer((T.int64(10), T.int64(20))) + compute = T.alloc_buffer((T.int64(10), T.int64(20))) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_squeeze[v_ax0, v_ax1]) + T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1] + + _check(before(), Expected) + + +def test_skip_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) + R.output(y) + return y + + # FuseTIR should make no change to it. + _check(Module, Module) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py new file mode 100644 index 000000000000..017a673e8fcf --- /dev/null +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax +import tvm.script +from tvm.script import relax as R, tir as T +from tvm.relax import transform +from tvm.ir.base import assert_structural_equal + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x, map_free_vars=True) + yhash = tvm.ir.structural_hash(y, map_free_vars=True) + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_basic(): + # the target IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + inner = Expected.lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @tvm.script.ir_module + class Before: + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_closure(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + outer_func = Expected.lifted_func_0 + in_call = outer_func(x) + res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))) + return res + + @R.function + def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")): + r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return r_1 + + @R.function + def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object: + inner_func = R.make_closure(Expected.lifted_func_1, (y,)) + return inner_func + + # IRModule to perform Lambda Lifting + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")): + @R.function + def inner_func(x1: R.Tensor((2, 3), 
"float32")) -> R.Tensor((2, 3), "float32"): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + before = Before + after = transform.LambdaLift()(before) + expected = Expected + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_recursive(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + new_r = Expected.lifted_func_0(new_i, new_s, x) + r = new_r + else: + r = s + return r + + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"): + while_loop = R.make_closure(Expected.lifted_func_0, (x,)) + gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure( + while_loop, + (R.const(0), x), + sinfo_args=(R.Tensor((2, 3), dtype="float32")), + ) + return gv + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s) + else: + r: R.Tensor((2, 3), "float32") = s + return r + + gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x) + return gv + + before = Before + expected = Expected + # check well-formness of recursive call + assert relax.analysis.well_formed(before) + + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_multi_func(): + # expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner = Expected.lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner = Expected.lifted_func_1 + gv11: R.Tensor((10, 5), "float32") = inner(x11, y11) + return gv11 + + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def lifted_func_1( + x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s1: R.Tensor((10, 5), "float32") = R.add(x21, y21) + return s1 + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: 
R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 4 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_no_local_func(): + @tvm.script.ir_module + class Before: + @T.prim_func + def sub( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)): + s = R.call_tir(Before.sub, (c0, x), R.Tensor((16, 16), dtype="float32")) + return s + + before = Before + # Perform lambda lifting + after = transform.LambdaLift()(before) + # No local functions are lifted + assert_structural_equal(after, before, map_free_vars=True) + _check_save_roundtrip(after) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py new file mode 100644 index 000000000000..73c5770c5dbd --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import relax +from tvm.relax.transform import LegalizeOps +from tvm.relax.transform.legalize_ops.common import register_legalize +from tvm.script import relax as R, tir as T +import tvm.testing + + +def test_customize_legalize(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + cls = Expected + gv = R.call_tir(cls.add, (y, x), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call): + from tvm import topi # pylint: disable=import-outside-toplevel + + return bb.call_te(topi.add, call.args[1], call.args[0]) + + mod = LegalizeOps({"relax.add": customize_legalize_add})(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_legalize_multiple_types_of_call(): + # fmt: off + @tvm.script.ir_module + class Before: + @R.function + def mul2(x: R.Tensor((3, 3), "float32")): + gv = R.multiply(x, R.const(2.0, "float32")) + return gv + + @T.prim_func + def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_id[v_ax0, v_ax1]) + T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + + @R.function + def main(x: R.Tensor((3, 3), "float32")): + cls = Before + gv: R.Tensor((3, 3), "float32") = cls.mul2(x) + gv1 = R.call_tir(cls.identity, gv, R.Tensor((3, 3), dtype="float32")) + gv2 = R.multiply(gv1, R.const(2.0, "float32")) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def mul2(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + cls = Expected + gv = R.call_tir(cls.multiply, (x,), R.Tensor((3, 3), dtype="float32")) + return gv + + @T.prim_func + def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_id[v_ax0, v_ax1]) + T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + + @T.prim_func + def multiply(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in 
T.grid(T.int64(3), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.float32(2) + + @R.function + def main(x1: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + cls = Expected + gv1: R.Tensor((3, 3), dtype="float32") = cls.mul2(x1) + gv11 = R.call_tir(cls.identity, gv1, R.Tensor((3, 3), dtype="float32")) + gv2 = R.call_tir(cls.multiply, (gv11,), R.Tensor((3, 3), dtype="float32")) + return gv2 + # fmt: on + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_can_not_legalize(): + # case 1: the op doesn't have a registered legalization + add_legalize = tvm.ir.Op.get("relax.add").get_attr("FLegalize") + # reset it for test + tvm.ir.Op.get("relax.add").reset_attr("FLegalize") + + # fmt: off + @tvm.script.ir_module + class Before0: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + gv: R.Tensor((3, 3), "float32") = R.add(x, x) + return gv + # fmt: on + After0 = LegalizeOps()(Before0) + tvm.ir.assert_structural_equal(After0, Before0) + + register_legalize("relax.add", add_legalize) + + # case 2: not all shapes are fully known + s = relax.Var("s", relax.ShapeStructInfo((3, 3))) + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + y = relax.Var("y", relax.TensorStructInfo(s, "float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(R.add(x, y)) + bb.emit_func_output(gv) + Before1 = bb.get() + After1 = LegalizeOps()(Before1) + tvm.ir.assert_structural_equal(After1, Before1) + + +def test_legalize_scalar_data_type_preserve(): + # fmt: off + @tvm.script.ir_module + class Before0: + @R.function + def main(x: R.Tensor((3, 3), "float16")): + gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514, "float16")) + return gv + + @tvm.script.ir_module + class Before1: + @R.function + def main(x: R.Tensor((3, 3), "uint8")): + gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8")) + return gv + + @tvm.script.ir_module + class Before2: + @R.function + def main(x: R.Tensor((3, 3), "bool")): + gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool")) + return gv + + @tvm.script.ir_module + class Expected0: + @T.prim_func + def multiply( + rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float16"), + T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.float16( + 1.1455078125 + ) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3), dtype="float16"): + cls = Expected0 + gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="float16")) + return gv + + @tvm.script.ir_module + class Expected1: + @T.prim_func + def multiply( + rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "uint8"), + T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) +
T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.uint8(2) + + @R.function + def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3), dtype="uint8"): + cls = Expected1 + gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="uint8")) + return gv + + @tvm.script.ir_module + class Expected2: + @T.prim_func + def equal( + rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "bool"), + T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_equal"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_equal[v_ax0, v_ax1]) + T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] == tvm.tir.const(True, "bool") + + @R.function + def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), dtype="bool"): + cls = Expected2 + gv = R.call_tir(cls.equal, (x,), out_sinfo=R.Tensor((3, 3), dtype="bool")) + return gv + # fmt: on + + After0 = LegalizeOps()(Before0) + tvm.ir.assert_structural_equal(After0, Expected0) + After1 = LegalizeOps()(Before1) + tvm.ir.assert_structural_equal(After1, Expected1) + After2 = LegalizeOps()(Before2) + tvm.ir.assert_structural_equal(After2, Expected2) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py new file mode 100644 index 000000000000..dc14a0c3fd40 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -0,0 +1,1611 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
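+#
+# A brief sketch of what the tests below verify (illustrative; `BinaryMod` is a
+# placeholder module name): running `LegalizeOps()(BinaryMod)` rewrites a binary
+# Relax op such as R.add(x, y) into an R.call_tir to a generated TIR PrimFunc
+# that computes the broadcasted element-wise result. Comparison ops lower the
+# same way but produce "bool" outputs, and symbolic shapes show up as
+# T.match_buffer parameters of the generated PrimFunc.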
+ +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Binary arithmetic ##################### + + +def test_add(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.add, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_add_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.add(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.add, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_add_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.add(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.add, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T.float32(1) + rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def 
test_add_symbolic(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.add, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_add = T.match_buffer(var_T_add, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.divide, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.divide(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def 
divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder[ax0, ax1] / T.float32(1) + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.divide(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = T.float32(1) / rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_symbolic(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.floor_divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + 
@R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.floor_divide, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_floor_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) + T_floor_divide[ax0, ax1, ax2, ax3] = T.floor(rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.floor_divide(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.floor_divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_floor_divide[ax0, ax1]) + T_floor_divide[ax0, ax1] = T.floor(rxplaceholder[ax0, ax1] / T.float32(1)) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.floor_divide(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.floor_divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_floor_divide[ax0, ax1]) + T_floor_divide[ax0, ax1] = T.floor(T.float32(1) / rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_symbolic(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() 
+ b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.floor_divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.floor_divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_floor_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_floor_divide = T.match_buffer(var_T_floor_divide, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_floor_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) + T_floor_divide[ax0, ax1, ax2, ax3] = T.floor(rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_multiply(): + # fmt: off + @tvm.script.ir_module + class Multiply: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.multiply(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.multiply, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def multiply(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_multiply: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] * rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Multiply) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_multiply_symbolic(): + # fmt: off + @tvm.script.ir_module + class Multiply: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.multiply(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = 
R.call_tir(Expected.multiply, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] * rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Multiply) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_power(): + # fmt: off + @tvm.script.ir_module + class Power: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.power(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_power: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_power"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) + T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3]) + T_power[v_ax0, v_ax1, v_ax2, v_ax3] = T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) + + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"): + gv = R.call_tir(Expected.power, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + # fmt: on + + mod = LegalizeOps()(Power) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_power_symbolic(): + # fmt: off + @tvm.script.ir_module + class Power: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.power(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def power(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_power: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d)) + a = T.int64() + b = T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1))) + T_power = T.match_buffer(var_T_power, (a, b, c, d)) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d): + with T.block("T_power"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + 
T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) + T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3]) + T_power[v_ax0, v_ax1, v_ax2, v_ax3] = T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) + + @R.function + def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c", "d"), dtype="float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.power, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Power) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_subtract(): + # fmt: off + @tvm.script.ir_module + class Subtract: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.subtract(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.subtract, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def subtract(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_subtract: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] - rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Subtract) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_subtract_symbolic(): + # fmt: off + @tvm.script.ir_module + class Subtract: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.subtract(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.subtract, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_subtract: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_subtract = T.match_buffer(var_T_subtract, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2,
T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] - rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Subtract) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Binary comparison ##################### + + +def test_equal(): + # fmt: off + @tvm.script.ir_module + class Equal: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_equal[ax0, ax1, ax2, ax3]) + T_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] == rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Equal) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.equal(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(Expected.equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_equal[ax0, ax1]) + T_equal[ax0, ax1] = rxplaceholder[ax0, ax1] == T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.equal(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(Expected.equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_equal[ax0, ax1]) + T_equal[ax0, ax1] = 
T.float32(1) == rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class Equal: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_equal = T.match_buffer(var_T_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_equal[ax0, ax1, ax2, ax3]) + T_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] == rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Equal) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater(): + # fmt: off + @tvm.script.ir_module + class Greater: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.greater(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.greater, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater[ax0, ax1, ax2, ax3]) + T_greater[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] < rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(Greater) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.greater(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"bool"): + gv = R.call_tir(Expected.greater, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_greater[ax0, ax1]) + T_greater[ax0, ax1] = T.float32(1) < rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.greater(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(Expected.greater, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_greater[ax0, ax1]) + T_greater[ax0, ax1] = rxplaceholder[ax0, ax1] < T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_symbolic(): + # fmt: off + @tvm.script.ir_module + class Greater: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.greater(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.greater, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_greater = T.match_buffer(var_T_greater, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_greater"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater[ax0, ax1, ax2, ax3]) + T_greater[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] < rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(Greater) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_equal(): + # fmt: off + @tvm.script.ir_module + class GreaterEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 
3), "bool") = R.greater_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.greater_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_greater_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) + T_greater_equal[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] <= rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(GreaterEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class GreaterEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.greater_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.greater_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_greater_equal = T.match_buffer(var_T_greater_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_greater_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) + T_greater_equal[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] <= rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(GreaterEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less(): + # fmt: off + @tvm.script.ir_module + class Less: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.less(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.less, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def less(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), 
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_less"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less[ax0, ax1, ax2, ax3]) + T_less[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] < rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Less) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_symbolic(): + # fmt: off + @tvm.script.ir_module + class Less: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.less(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.less, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_less = T.match_buffer(var_T_less, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_less"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less[ax0, ax1, ax2, ax3]) + T_less[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] < rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Less) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal(): + # fmt: off + @tvm.script.ir_module + class LessEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.less_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.less_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less_equal[ax0, ax1, ax2, ax3]) + 
T_less_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] <= rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(LessEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.less_equal(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(Expected.less_equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_less_equal[ax0, ax1]) + T_less_equal[ax0, ax1] = rxplaceholder[ax0, ax1] <= T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.less_equal(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(Expected.less_equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_less_equal[ax0, ax1]) + T_less_equal[ax0, ax1] = T.float32(1) <= rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class LessEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.less_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.less_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_less_equal = T.match_buffer(var_T_less_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): 
+ with T.block("T_less_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less_equal[ax0, ax1, ax2, ax3]) + T_less_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] <= rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(LessEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_not_equal(): + # fmt: off + @tvm.script.ir_module + class NotEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.not_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(Expected.not_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def not_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_not_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_not_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_not_equal[ax0, ax1, ax2, ax3]) + T_not_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] != rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(NotEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_not_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class NotEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "bool") = R.not_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.not_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_not_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_not_equal = T.match_buffer(var_T_not_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_not_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_not_equal[ax0, ax1, ax2, ax3]) + T_not_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] != rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(NotEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def 
test_maximum(): + # fmt: off + @tvm.script.ir_module + class Maximum: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.maximum(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def maximum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_maximum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_maximum"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_maximum[ax0, ax1, ax2, ax3]) + T_maximum[ax0, ax1, ax2, ax3] = T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Maximum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_maximum_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Maximum: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.maximum(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_maximum"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_maximum[ax0, ax1]) + T_maximum[ax0, ax1] = T.max(rxplaceholder[ax0, ax1], T.float32(1)) + # fmt: on + + mod = LegalizeOps()(Maximum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_maximum_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Maximum: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.maximum(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_maximum"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_maximum[ax0, ax1]) + T_maximum[ax0, ax1] = T.max(T.float32(1), rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(Maximum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_maximum_symbolic(): + # fmt: off + @tvm.script.ir_module + class 
Maximum: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.maximum(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_maximum: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_maximum = T.match_buffer(var_T_maximum, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_maximum"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_maximum[ax0, ax1, ax2, ax3]) + T_maximum[ax0, ax1, ax2, ax3] = T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Maximum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_minimum(): + # fmt: off + @tvm.script.ir_module + class Minimum: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.minimum(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def minimum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_minimum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_minimum"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_minimum[ax0, ax1, ax2, ax3]) + T_minimum[ax0, ax1, ax2, ax3] = T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Minimum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_minimum_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Minimum: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.minimum(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def 
minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_minimum"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_minimum[ax0, ax1]) + T_minimum[ax0, ax1] = T.min(rxplaceholder[ax0, ax1], T.float32(1)) + # fmt: on + + mod = LegalizeOps()(Minimum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_minimum_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Minimum: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.minimum(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_minimum"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_minimum[ax0, ax1]) + T_minimum[ax0, ax1] = T.min(T.float32(1), rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(Minimum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_minimum_symbolic(): + # fmt: off + @tvm.script.ir_module + class Minimum: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.minimum(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_minimum: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_minimum = T.match_buffer(var_T_minimum, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_minimum"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_minimum[ax0, ax1, ax2, ax3]) + T_minimum[ax0, ax1, ax2, ax3] = T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Minimum) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py new file mode 100644 index 000000000000..b0584ce75954 --- /dev/null +++ 
b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -0,0 +1,806 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Creation ##################### + + +def test_full(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full((2, 3), v, dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_constant_scalar_fill_value(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main() -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full((2, 3), R.const(3.5, "float32"), dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 3 + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_different_dtype(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full((2, 3), v, dtype="float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + 
T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.Cast("float32", rxplaceholder[()]) + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_symbolic(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "int32") = R.full((m, n), v, dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="int32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full_like(x, v) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_constant_scalar_fill_value(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(-5) + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_different_dtype(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float64"): + gv: R.Tensor((2, 3), "float64") = R.full_like(x, v, dtype="float64") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: 
R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float64"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float64")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float64")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.Cast("float64", rxplaceholder[()]) + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.full_like(x, v) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones(): + # fmt: off + @tvm.script.ir_module + class Ones: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones((2, 3), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.ones, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(Ones) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones_symbolic(): + # fmt: off + @tvm.script.ir_module + class Ones: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.ones((m, n), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.ones, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def ones(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(Ones) + tvm.ir.assert_structural_equal(mod, 
Expected) + + +def test_ones_like(): + # fmt: off + @tvm.script.ir_module + class OnesLike: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.ones_like(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.ones, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 1 + # fmt: on + + mod = LegalizeOps()(OnesLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class OnesLike: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.ones_like(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.ones, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def ones(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(OnesLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros(): + # fmt: off + @tvm.script.ir_module + class Zeros: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros((2, 3), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.zeros, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(Zeros) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_symbolic(): + # fmt: off + @tvm.script.ir_module + class Zeros: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.zeros((m, n), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def zeros(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, 
ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(Zeros) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_like(): + # fmt: off + @tvm.script.ir_module + class ZerosLike: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.zeros_like(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.zeros, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 0 + # fmt: on + + mod = LegalizeOps()(ZerosLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class ZerosLike: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.zeros_like(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def zeros(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(ZerosLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tril(): + # fmt: off + @tvm.script.ir_module + class Tril: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.tril(x, k=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(Expected.tril, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def tril(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 - T.int64(1) <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Tril) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tril_symbolic(): + # fmt: off + @tvm.script.ir_module + class Tril: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.int64() + n = T.int64() + k = T.int64() + gv: R.Tensor((m, n, k), "int8") = R.tril(x, k=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.int64() + n = T.int64() + k = T.int64() + 
gv = R.call_tir(Expected.tril, (x,), R.Tensor((m, n, k), dtype="int8")) + return gv + + @T.prim_func + def tril(var_rxplaceholder: T.handle, var_trilu: T.handle): + T.func_attr({"tir.noalias": True}) + k = T.int64() + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") + trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") + for i0, i1, i2 in T.grid(m, n, k): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 + T.int64(2) <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.int8(0)) + # fmt: on + + mod = LegalizeOps()(Tril) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_triu(): + # fmt: off + @tvm.script.ir_module + class Triu: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.triu(x, k=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(Expected.triu, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def triu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 <= i2_1 - T.int64(1), rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Triu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_triu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Triu: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.int64() + n = T.int64() + k = T.int64() + gv: R.Tensor((m, n, k), "int8") = R.triu(x, k=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.int64() + n = T.int64() + k = T.int64() + gv = R.call_tir(Expected.triu, (x,), R.Tensor((m, n, k), dtype="int8")) + return gv + + @T.prim_func + def triu(var_rxplaceholder: T.handle, var_trilu: T.handle): + T.func_attr({"tir.noalias": True}) + k = T.int64() + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") + trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") + for i0, i1, i2 in T.grid(m, n, k): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 <= i2_1 + T.int64(2), rxplaceholder[i0_1, i1_1, i2_1], T.int8(0)) + # fmt: on + + mod = LegalizeOps()(Triu) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Datatype ##################### + + +def test_astype(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "int32"): + gv: R.Tensor((2, 3, 4), "int32") = R.astype(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> 
R.Tensor((2, 3, 4), "int32"): + gv = R.call_tir(Expected.cast, (x,), R.Tensor((2, 3, 4), dtype="int32")) + return gv + + @T.prim_func + def cast(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("compute"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(compute[i0_1, i1_1, i2_1]) + compute[i0_1, i1_1, i2_1] = T.Cast("int32", rxplaceholder[i0_1, i1_1, i2_1]) + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_astype_input_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main() -> R.Tensor((), "int32"): + gv: R.Tensor((), "int32") = R.astype(R.const(1.5, "float32"), "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((), "int32"): + gv: R.Tensor((), "int32") = R.const(1, "int32") + return gv + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_astype_symbolic(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "int32") = R.astype(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.cast, (x,), R.Tensor((m, n), dtype="int32")) + return gv + + @T.prim_func + def cast(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="int32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.Cast("int32", rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py new file mode 100644 index 000000000000..18acb282c294 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_image.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
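The resize2d tests below, like the comparison, creation, and datatype tests above, all follow the same legalize-then-compare pattern: build a small Relax module, run LegalizeOps, and check the result against a hand-written Expected module with tvm.ir.assert_structural_equal. A minimal sketch of that pattern for orientation (illustrative only, not part of this patch; it uses an R.add placeholder workload and prints the lowered TVMScript instead of asserting):

import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R

# Illustrative input module only; the ops actually exercised by this patch are
# resize2d below and the creation/datatype/binary ops in the sibling test files.
@tvm.script.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
        gv = R.add(x, y)
        return gv

lowered = LegalizeOps()(Before)  # the high-level op is rewritten into R.call_tir of a generated TIR PrimFunc
print(lowered.script())          # inspect the lowered module instead of comparing structurally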
+ +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +def test_image_resize2d(): + # fmt: off + @tvm.script.ir_module + class Resize2D: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "float32"): + gv: R.Tensor((2, 16, 16, 3), "float32") = R.image.resize2d(x, size=(16, 16), layout="NHWC", method="nearest_neighbor", coordinate_transformation_mode="asymmetric") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "float32"): + gv = R.call_tir(Expected.resize2d, (x,), R.Tensor((2, 16, 16, 3), dtype="float32")) + return gv + + @T.prim_func + def resize2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(8), T.int64(8), T.int64(3)), "float32"), resize: T.Buffer((T.int64(2), T.int64(16), T.int64(16), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1]) + resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1] + # fmt: on + + mod = LegalizeOps()(Resize2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_image_resize2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Resize2D: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() + gv: R.Tensor((n, c, oh, ow, 16), "float32") = R.image.resize2d(x, size=(oh, ow), layout="NCHW16c", method="nearest_neighbor", coordinate_transformation_mode="asymmetric") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() + gv = R.call_tir(Expected.resize2d, (x,), R.Tensor((n, c, oh, ow, 16), dtype="float32")) + return gv + + @T.prim_func + def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.int64() + h = T.int64() + n = T.int64() + oh = T.int64() + ow = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w, T.int64(16)], dtype="float32") + resize = T.match_buffer(var_resize, [n, c, oh, ow, T.int64(16)], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(n, c, oh, ow, T.int64(16)): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i1_1, T.int64(0) : T.max(h, T.int64(1)), T.int64(0) : T.max(w, T.int64(1)), i4_1]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1, i4_1]) + resize[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i1_1, T.max(T.min(T.Cast("int64", T.round(T.Cast("float32", h) / T.Cast("float32", oh) * T.Cast("float32", i2_1), dtype="float32")), h - T.int64(1)), T.int64(0)), T.max(T.min(T.Cast("int64", T.round(T.Cast("float32", w) / T.Cast("float32", ow) * 
T.Cast("float32", i3_1), dtype="float32")), w - T.int64(1)), T.int64(0)), i4_1] + # fmt: on + + mod = LegalizeOps()(Resize2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py new file mode 100644 index 000000000000..e19dc7b5ed3c --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Indexing ##################### + + +def test_take(): + # fmt: off + @tvm.script.ir_module + class Take: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> R.Tensor((2, 4, 4), "float32"): + gv: R.Tensor((2, 4, 4), "float32") = R.take(x, indices, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> R.Tensor((2, 4, 4), "float32"): + gv = R.call_tir(Expected.take, (x, indices), R.Tensor((2, 4, 4), dtype="float32")) + return gv + + @T.prim_func + def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "int64"), T_take: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): + with T.block("T_take"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1], ax2], rxplaceholder_1[ax1]) + T.writes(T_take[ax0, ax1, ax2]) + T_take[ax0, ax1, ax2] = rxplaceholder[ax0, rxplaceholder_1[ax1], ax2] + # fmt: on + + mod = LegalizeOps()(Take) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_take_symbolic(): + # fmt: off + @tvm.script.ir_module + class Take: + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): + m = T.int64() + i = T.int64() + gv: R.Tensor((m, i), "float32") = R.take(x, indices, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): + m = T.int64() + i = T.int64() + gv = R.call_tir(Expected.take, (x, indices), R.Tensor((m, i), dtype="float32")) + return gv + + @T.prim_func + def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take: T.handle): + T.func_attr({"tir.noalias": True}) + i = T.int64() + m = T.int64() + n 
= T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [i], dtype="int64") + T_take = T.match_buffer(var_T_take, [m, i], dtype="float32") + for i0, i1 in T.grid(m, i): + with T.block("T_take"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1]], rxplaceholder_1[ax1]) + T.writes(T_take[ax0, ax1]) + T_take[ax0, ax1] = rxplaceholder[ax0, rxplaceholder_1[ax1]] + # fmt: on + + mod = LegalizeOps()(Take) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3), dtype="float32"): + gv = R.call_tir(Expected.strided_slice, (x,), R.Tensor((4, 9, 10, 3), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)): + with T.block("T_strided_slice_with_axes"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)]) + T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3]) + T_strided_slice_with_axes[ax0, ax1, ax2, ax3] = rxplaceholder[ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice_no_strides(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3), dtype="float32"): + gv = R.call_tir(Expected.strided_slice, (x,), out_sinfo=R.Tensor((7, 9, 10, 2), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(7), T.int64(9), T.int64(10), T.int64(2)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(7), T.int64(9), T.int64(10), T.int64(2)): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0 + T.int64(1), v_ax1, v_ax2, v_ax3 + T.int64(2)]) + T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0 + T.int64(1), v_ax1, v_ax2, v_ax3 + T.int64(2)] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice_symbolic_sliced_axis(): + # fmt: off + 
@tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"): + n = T.int64() + gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) + return gv + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, StridedSlice) + + +def test_strided_slice_symbolic(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), "float32"): + n = T.int64() + gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="float32"): + n = T.int64() + gv = R.call_tir(Expected.strided_slice, (x,), R.Tensor((3, n), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") + T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") + for i0, i1 in T.grid(T.int64(3), n): + with T.block("T_strided_slice_with_axes"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1]) + T.writes(T_strided_slice_with_axes[ax0, ax1]) + T_strided_slice_with_axes[ax0, ax1] = rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Linear algebra ##################### + + +def test_matmul_1_4(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 3, 5), "float32"): + gv: R.Tensor((2, 3, 5), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 3, 5), "float32"): + gv = R.call_tir(Expected.matmul, (x, y), R.Tensor((2, 3, 5), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(5), T.int64(4)): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k], rxplaceholder_1[i0_1, i1_1, k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1] = matmul[i0_1, i1_1, i2_1] + rxplaceholder[k] * rxplaceholder_1[i0_1, i1_1, k, i2_1] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_1(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> 
R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(Expected.matmul, (x, y), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer(T.int64(5), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, k], rxplaceholder_1[k]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1] = matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, i2_1, k] * rxplaceholder_1[k] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_1_1(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(Expected.matmul, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "float32"), matmul: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(4)): + with T.block("matmul"): + k = T.axis.reduce(T.int64(4), i0) + T.reads(rxplaceholder[k], rxplaceholder_1[k]) + T.writes(matmul[()]) + with T.init(): + matmul[()] = T.float32(0) + matmul[()] = matmul[()] + rxplaceholder[k] * rxplaceholder_1[k] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_5(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "float16")) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y, out_dtype="float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "float16")) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv = R.call_tir(Expected.matmul, (x, y), R.Tensor((6, 2, 3, 4, 7), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float16"), matmul: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5)): + with T.block("matmul"): + i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i1_1, i2_1, i3_1, k], rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1]) + T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = matmul[i0_1, i1_1, i2_1, i3_1, i4_1] + T.Cast("float32", rxplaceholder[i1_1, i2_1, i3_1, k]) * T.Cast("float32", 
rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1]) + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_5_symbolic(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + m = T.int64() + n = T.int64() + gv: R.Tensor((a, b, c, m, n), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.matmul, (x, y), R.Tensor((a, b, c, m, n), dtype="float32")) + return gv + + @T.prim_func + def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + k = T.int64() + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), m, k], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, T.int64(1), c, k, n], dtype="float32") + matmul = T.match_buffer(var_matmul, [a, b, c, m, n], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, b, c, m, n, k): + with T.block("matmul"): + i0_1, i1_1, i2_1, i3_1, i4_1, k_1 = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i1_1, T.int64(0), i3_1, k_1], rxplaceholder_1[i0_1, T.int64(0), i2_1, k_1, i4_1]) + T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = matmul[i0_1, i1_1, i2_1, i3_1, i4_1] + rxplaceholder[i1_1, T.int64(0), i3_1, k_1] * rxplaceholder_1[i0_1, T.int64(0), i2_1, k_1, i4_1] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py new file mode 100644 index 000000000000..cce35a90263c --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -0,0 +1,1357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
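+"""Structural tests for LegalizeOps on Relax manipulation operators.
+
+Each test below builds a small IRModule, runs relax.transform.LegalizeOps(), and
+checks structural equality against an Expected module in which the high-level op
+has been lowered to R.call_tir plus a generated TIR prim func."""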
+ +import pytest +import tvm +from tvm import relax +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T, ir as I +import tvm.testing + + +##################### Manipulation ##################### + + +def test_broadcast_to(): + # fmt: off + @tvm.script.ir_module + class BroadcastTo: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv = R.call_tir(Expected.broadcast_to, (x,), R.Tensor((4, 2, 5, 3), dtype="float32")) + return gv + + @T.prim_func + def broadcast_to(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3)), "float32"), T_broadcast_to: T.Buffer((T.int64(4), T.int64(2), T.int64(5), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(2), T.int64(5), T.int64(3)): + with T.block("T_broadcast_to"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax1, T.int64(0), ax3]) + T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(BroadcastTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_broadcast_to_symbolic(): + # fmt: off + @tvm.script.ir_module + class BroadcastTo: + @R.function + def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, b, c, d), "float32") = R.broadcast_to(x, (a, b, c, d)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.broadcast_to, (x,), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), d], dtype="float32") + T_broadcast_to = T.match_buffer(var_T_broadcast_to, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_broadcast_to"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax1, T.int64(0), ax3]) + T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(BroadcastTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), x3: R.Tensor((1, 4, 3), "float32")) -> R.Tensor((1, 9, 3), "float32"): + gv: R.Tensor((1, 9, 3), "float32") = R.concat((x1, x2, x3), axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), x3: R.Tensor((1, 4, 3), "float32")) -> R.Tensor((1, 9, 3), "float32"): + gv = R.call_tir(Expected.concatenate, (x1, x2, x3), R.Tensor((1, 9, 
3), dtype="float32")) + return gv + + @T.prim_func + def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(3), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(1), T.int64(4), T.int64(3)), "float32"), T_concat: T.Buffer((T.int64(1), T.int64(9), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(9), T.int64(3)): + with T.block("T_concat"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2]) + T.writes(T_concat[ax0, ax1, ax2]) + T_concat[ax0, ax1, ax2] = T.if_then_else(T.int64(5) <= ax1, rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], T.if_then_else(T.int64(2) <= ax1, rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2])) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat_input_tuple_var(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) -> R.Tensor((3, 9), "float32"): + gv: R.Tensor((3, 9), "float32") = R.concat(t, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) -> R.Tensor((3, 9), "float32"): + gv: R.Tensor((3, 4), dtype="float32") = t[0] + gv1: R.Tensor((3, 5), dtype="float32") = t[1] + gv2 = R.call_tir(Expected.concatenate, (gv, gv1), R.Tensor((3, 9), dtype="float32")) + return gv2 + + @T.prim_func + def concatenate(rxplaceholder: T.Buffer((T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(5)), "float32"), T_concat: T.Buffer((T.int64(3), T.int64(9)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(3), T.int64(9)): + with T.block("T_concat"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_1[ax0, ax1 - T.int64(4)], rxplaceholder[ax0, ax1]) + T.writes(T_concat[ax0, ax1]) + T_concat[ax0, ax1] = T.if_then_else(T.int64(4) <= ax1, rxplaceholder_1[ax0, ax1 - T.int64(4)], rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat_input_tuple_var_symbolic(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): + a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() + gv: R.Tensor((a, b0 + b1 + b2), "float32") = R.concat(t, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): + a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() + gv: R.Tensor((a, b0), dtype="float32") = t[0] + gv1: R.Tensor((a, b1), dtype="float32") = t[1] + gv2: R.Tensor((a, b2), dtype="float32") = t[2] + gv3 = R.call_tir(Expected.concatenate, (gv, gv1, gv2), R.Tensor((a, ((b0 + b1) + b2)), dtype="float32")) + return gv3 + + @T.prim_func + def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_concat: T.handle): + T.func_attr({"tir.noalias": True}) 
+ a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b0], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b1], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [a, b2], dtype="float32") + T_concat = T.match_buffer(var_T_concat, [a, b0 + b1 + b2], dtype="float32") + for i0, i1 in T.grid(a, b0 + b1 + b2): + with T.block("T_concat"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_2[ax0, ax1 - b0 - b1], rxplaceholder_1[ax0, ax1 - b0], rxplaceholder[ax0, ax1]) + T.writes(T_concat[ax0, ax1]) + T_concat[ax0, ax1] = T.if_then_else(T.int64(0) <= ax1 - b0 - b1, rxplaceholder_2[ax0, ax1 - b0 - b1], T.if_then_else(T.int64(0) <= ax1 - b0, rxplaceholder_1[ax0, ax1 - b0], rxplaceholder[ax0, ax1])) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_expand_dims(): + # fmt: off + @tvm.script.ir_module + class ExpandDims: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims(x, axis=[-1, 1, -6, 3, 5]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv = R.call_tir(Expected.expand_dims, (x,), R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), dtype="float32")) + return gv + + @T.prim_func + def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), expand_dims: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap("SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[i0_1, i4_1, i6_1] + # fmt: on + + mod = LegalizeOps()(ExpandDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_expand_dims_symbolic(): + # fmt: off + @tvm.script.ir_module + class ExpandDims: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv: R.Tensor((a, 1, b, 1, c, 1), "float32") = R.expand_dims(x, axis=[1, 3, 5]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.expand_dims, (x,), R.Tensor((a, 1, b, 1, c, 1), dtype="float32")) + return gv + + @T.prim_func + def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + expand_dims = T.match_buffer(var_expand_dims, [a, T.int64(1), b, T.int64(1), c, T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), b, T.int64(1), c, T.int64(1)): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = 
T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1] + # fmt: on + + mod = LegalizeOps()(ExpandDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): + gv: R.Tensor((24,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): + gv = R.call_tir(Expected.reshape, (x,), R.Tensor((24,), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(24), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(24)): + with T.block("T_reshape"): + ax0 = T.axis.spatial(T.int64(24), i0) + T.reads(rxplaceholder[ax0 % T.int64(24) // T.int64(12), ax0 % T.int64(12) // T.int64(4), ax0 % T.int64(4)]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[ax0 % T.int64(24) // T.int64(12), ax0 % T.int64(12) // T.int64(4), ax0 % T.int64(4)] + # fmt: on + + mod = LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten_zero_rank(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): + gv: R.Tensor((1,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): + gv = R.call_tir(Expected.reshape, (x,), R.Tensor((1,), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((), "float32"), T_reshape: T.Buffer(T.int64(1), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(1)): + with T.block("T_reshape"): + ax0 = T.axis.spatial(T.int64(1), i0) + T.reads(rxplaceholder[()]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten_symbolic(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv: R.Tensor((a * b * c,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.reshape, (x,), R.Tensor((((a * b) * c),), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + T_reshape = T.match_buffer(var_T_reshape, [a * b * c], dtype="float32") + for i0 in T.serial(a * b * c): + with T.block("T_reshape"): + ax0 = T.axis.spatial(a * b * c, i0) + T.reads(rxplaceholder[ax0 // c // b % a, ax0 // c % b, ax0 % c]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[ax0 // c // b % a, ax0 // c % b, ax0 % c] + # fmt: on + + mod = 
LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_permute_dims(): + # fmt: off + @tvm.script.ir_module + class PermuteDims: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv: R.Tensor((2, 4, 3, 1), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv = R.call_tir(Expected.transpose, (x,), R.Tensor((2, 4, 3, 1), dtype="float32")) + return gv + + @T.prim_func + def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(4), T.int64(3), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(4), T.int64(3), T.int64(1)): + with T.block("T_transpose"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) + T.writes(T_transpose[ax0, ax1, ax2, ax3]) + T_transpose[ax0, ax1, ax2, ax3] = rxplaceholder[ax3, ax0, ax2, ax1] + # fmt: on + + mod = LegalizeOps()(PermuteDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_permute_dims_symbolic(): + # fmt: off + @tvm.script.ir_module + class PermuteDims: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "d", "c", "a"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((b, d, c, a), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "d", "c", "a"), dtype="float32"): + b = T.int64() + d = T.int64() + c = T.int64() + a = T.int64() + gv = R.call_tir(Expected.transpose, (x,), R.Tensor((b, d, c, a), dtype="float32")) + return gv + + @T.prim_func + def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_transpose = T.match_buffer(var_T_transpose, [b, d, c, a], dtype="float32") + for i0, i1, i2, i3 in T.grid(b, d, c, a): + with T.block("T_transpose"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) + T.writes(T_transpose[ax0, ax1, ax2, ax3]) + T_transpose[ax0, ax1, ax2, ax3] = rxplaceholder[ax3, ax0, ax2, ax1] + # fmt: on + + mod = LegalizeOps()(PermuteDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape(): + # fmt: off + @tvm.script.ir_module + class Reshape: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv: R.Tensor((8, 3), "float32") = R.reshape(x, (8, 3)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv = R.call_tir(Expected.reshape, (x,), R.Tensor((8, 3), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(8), T.int64(3)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[T.int64(0), 
(ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), (ax0 * T.int64(3) + ax1) % T.int64(4)]) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[T.int64(0), (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), (ax0 * T.int64(3) + ax1) % T.int64(4)] + # fmt: on + + mod = LegalizeOps()(Reshape) + tvm.ir.assert_structural_equal(mod, Expected) + + # fmt: off + # ShapeExpr might be produced by shape computation + @tvm.script.ir_module + class Reshape2: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + lv: R.Shape((8, 3)) = R.shape((8, 3)) + gv: R.Tensor((8, 3), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @tvm.script.ir_module + class Expected2: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), + T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(8), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + rxplaceholder[ + T.int64(0), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // T.int64(12), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // T.int64(4), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(4), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1]) + T_reshape[v_ax0, v_ax1] = rxplaceholder[ + T.int64(0), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // T.int64(12), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // T.int64(4), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(4), + ] + + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8, 3), dtype="float32"): + lv: R.Shape((8, 3)) = R.shape((8, 3)) + gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8, 3), dtype="float32")) + return gv + # fmt: on + + mod2 = LegalizeOps()(Reshape2) + tvm.ir.assert_structural_equal(mod2, Expected2) + + +def test_reshape_symbolic(): + # fmt: off + @tvm.script.ir_module + class Reshape: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, (a // 2, b * 2)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + gv = R.call_tir(Expected.reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") + T_reshape = T.match_buffer(var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32") + for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b]) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b] + # fmt: on + + mod = LegalizeOps()(Reshape) + tvm.ir.assert_structural_equal(mod, Expected) + + # ShapeExpr might be produced by shape 
computation + @tvm.script.ir_module + class Reshape2: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2)) + gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @tvm.script.ir_module + class Expected2: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2)) + gv = R.call_tir(Expected2.reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") + T_reshape = T.match_buffer( + var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32" + ) + for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + rxplaceholder[ + (ax0 * (b * T.int64(2)) + ax1) // b % a, + (ax0 * (b * T.int64(2)) + ax1) % b, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b + ] + + mod2 = LegalizeOps()(Reshape2) + tvm.ir.assert_structural_equal(mod2, Expected2) + + # ShapeExpr might be produced by shape computation + @I.ir_module + class Reshape3: + @R.function + def main(x: R.Tensor((10, "b"), "float32")) -> R.Tensor((5, "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((5, b * 2)) = R.shape((5, b * 2)) + gv: R.Tensor((5, b * 2), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @I.ir_module + class Expected3: + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b)) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b * T.int64(2))) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(5), b * T.int64(2)): + with T.block("T_reshape"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + rxplaceholder[ + (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), + (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + ] + ) + T.writes(T_reshape[v_ax0, v_ax1]) + T_reshape[v_ax0, v_ax1] = rxplaceholder[ + (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), + (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + ] + + @R.function + def main( + x: R.Tensor((10, "b"), dtype="float32") + ) -> R.Tensor((5, "b * 2"), dtype="float32"): + b = T.int64() + lv: R.Shape([5, b * 2]) = R.shape([5, b * 2]) + gv = R.call_tir( + Expected3.reshape, (x,), out_sinfo=R.Tensor((5, b * 2), dtype="float32") + ) + return gv + + mod3 = LegalizeOps()(Reshape3) + tvm.ir.assert_structural_equal(mod3, Expected3) + + +def test_data_dependent_reshape(): + # fmt: off + @tvm.script.ir_module + class DDReshape: + @R.function + def main(x: R.Tensor((3, ), dtype="int64")): + lv: R.Shape([3,]) = R.tensor_to_shape(x) + gv = R.reshape(x, lv) + return gv + # fmt: on + + assert relax.analysis.well_formed(DDReshape) + mod = relax.transform.DecomposeCompositeOps()(DDReshape) + out_mod = relax.transform.LegalizeOps()(mod) 
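+    # After DecomposeCompositeOps and LegalizeOps, the R.tensor_to_shape call is
+    # expected to become a "vm.builtin.tensor_to_shape" packed call plus an
+    # R.match_cast that binds the symbolic extent, and the data-dependent reshape
+    # is lowered to call_tir, as the Expected module below spells out.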
+ + # fmt: off + @I.ir_module + class Expected: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: T.handle + ): + T.func_attr({"tir.noalias": True}) + x = T.int64() + T_reshape = T.match_buffer(var_T_reshape, (x,), "int64") + # with T.block("root"): + for ax0 in range(x): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(x, ax0) + T.reads(rxplaceholder[v_ax0 % T.int64(3)]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] + + @R.function + def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): + x_1 = T.int64() + gv: R.Shape([3]) = R.call_packed( + "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),) + ) + y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) + lv: R.Shape([x_1]) = R.shape([x_1]) + gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) + return gv_1 + # fmt: on + tvm.ir.assert_structural_equal(out_mod, Expected) + + +def test_split_by_indices(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]) = R.split(x, [3, 7], axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]) + return gv + + @T.prim_func + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_split_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_2: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_split"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split[ax0, ax1, ax2]) + T_split[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): + with T.block("T_split_1"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(3), ax2]) + T.writes(T_split_1[ax0, ax1, ax2]) + T_split_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(3), ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_split_2"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(7), ax2]) + T.writes(T_split_2[ax0, ax1, ax2]) + T_split_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(7), ax2] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_split_by_indices_n_section_indivisible(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1) + return gv + # fmt: on + + mod = LegalizeOps()(Split) 
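+    # The split axis has extent 10, which the 3 requested sections do not divide
+    # evenly, so LegalizeOps is expected to leave the op untouched here (contrast
+    # the divisible case below, which is lowered to a TIR split).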
+ tvm.ir.assert_structural_equal(mod, Split) + + +def test_split_by_indices_n_section_divisible(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) + return gv + + @T.prim_func + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + with T.block("T_split_sections"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split_sections[ax0, ax1, ax2]) + T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + with T.block("T_split_sections_1"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.writes(T_split_sections_1[ax0, ax1, ax2]) + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_split_by_indices_n_section_divisible_symbolic(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): + m = T.int64() + n = T.int64() + gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) + return gv + + @T.prim_func + def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + T.func_attr({"tir.noalias": True}) + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") + T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") + T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") + T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") 
+ for i0, i1 in T.grid(m, n): + with T.block("T_split_sections"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_split_sections[ax0, ax1]) + T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n + ax1]) + T.writes(T_split_sections_1[ax0, ax1]) + T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) + T.writes(T_split_sections_2[ax0, ax1]) + T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x, [1, 4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv = R.call_tir(Expected.squeeze, (x,), R.Tensor((2, 3, 1, 4), dtype="float32")) + return gv + + @T.prim_func + def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(1), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(4)): + with T.block("T_squeeze"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2, T.int64(0), ax3]) + T.writes(T_squeeze[ax0, ax1, ax2, ax3]) + T_squeeze[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, T.int64(0), ax1, ax2, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze_no_axis(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv = R.call_tir(Expected.squeeze, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_squeeze"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2]) + T.writes(T_squeeze[ax0, ax1, ax2]) + T_squeeze[ax0, ax1, ax2] = rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze_symbolic(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): + a = T.int64() + b = T.int64() + gv: R.Tensor((a, b, 1), "float32") = 
R.squeeze(x, [1]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): + a = T.int64() + b = T.int64() + gv = R.call_tir(Expected.squeeze, (x,), R.Tensor((a, b, 1), dtype="float32")) + return gv + + @T.prim_func + def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), b, T.int64(1)], dtype="float32") + T_squeeze = T.match_buffer(var_T_squeeze, [a, b, T.int64(1)], dtype="float32") + for i0, i1, i2 in T.grid(a, b, T.int64(1)): + with T.block("T_squeeze"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2]) + T.writes(T_squeeze[ax0, ax1, ax2]) + T_squeeze[ax0, ax1, ax2] = rxplaceholder[ax0, T.int64(0), ax1, ax2] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_collapse_sum_like(): + # fmt: off + @tvm.script.ir_module + class CollapseSumLike: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.collapse_sum_like(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Tensor((1, 3), "float32"): + gv = R.call_tir(Expected.collapse_sum, (x,), R.Tensor((1, 3), dtype="float32")) + return gv + + @T.prim_func + def collapse_sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)): + with T.block("rxplaceholder_red"): + ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[k0, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax1] + # fmt: on + + mod = LegalizeOps()(CollapseSumLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI collapse_sum not support symbolic now") +def test_collapse_sum_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class CollapseSumLike: + @R.function + def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("b", 1), "float32"): + b = T.int64() + gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y) + return gv + + # fmt: on + + mod = LegalizeOps()(CollapseSumLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_collapse_sum_to(): + # fmt: off + @tvm.script.ir_module + class CollapseSumTo: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1), "float32"): + gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3, 2, 3), dtype="float32") + ) -> R.Tensor((2, 1), dtype="float32"): + # block 0 + gv = R.call_tir(Expected.collapse_sum, (x,), R.Tensor((2, 1), dtype="float32")) + return gv + + @T.prim_func + def collapse_sum(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), 
T.int64(3)): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) + T.reads(rxplaceholder[v_k0, v_ax0, v_k2]) + T.writes(rxplaceholder_red[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red[v_ax0, v_ax1] = (rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2]) + # fmt: on + + mod = LegalizeOps()(CollapseSumTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI collapse_sum not support symbolic now") +def test_collapse_sum_to_symbolic(): + # fmt: off + @tvm.script.ir_module + class CollapseSumTo: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b", 1), "float32"): + b = T.int64() + gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1)) + return gv + + # fmt: on + + mod = LegalizeOps()(CollapseSumTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_repeat(): + # fmt: off + @I.ir_module + class Repeat: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")): + gv = R.repeat(x, 2, 0) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2, 3), dtype="float32"): + gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((6, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def repeat(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(6), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(T.int64(6), T.int64(2), T.int64(3)): + with T.block("T_repeat"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]) + T.writes(T_repeat[v_ax0, v_ax1, v_ax2]) + T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2] + # fmt: on + + mod = LegalizeOps()(Repeat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_repeat_no_axis(): + # fmt: off + @I.ir_module + class Repeat: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")): + gv = R.repeat(x, 2) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3, 2, 3), dtype="float32") + ) -> R.Tensor((36,), dtype="float32"): + gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((36,), dtype="float32")) + return gv + + @T.prim_func + def repeat( + rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), + T_repeat: T.Buffer((T.int64(36),), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + T_reshape = T.alloc_buffer((T.int64(18),)) + for ax0 in range(T.int64(18)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(18), ax0) + T.reads( + rxplaceholder[ + v_ax0 % T.int64(18) // T.int64(6), + v_ax0 % T.int64(6) // T.int64(3), + v_ax0 % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = rxplaceholder[ + v_ax0 % T.int64(18) // T.int64(6), + v_ax0 % T.int64(6) // T.int64(3), + v_ax0 % T.int64(3), + ] + for ax0 in range(T.int64(36)): + with T.block("T_repeat"): + v_ax0 = T.axis.spatial(T.int64(36), ax0) + T.reads(T_reshape[v_ax0 // T.int64(2)]) + T.writes(T_repeat[v_ax0]) + T_repeat[v_ax0] = T_reshape[v_ax0 // T.int64(2)] + # fmt: on + + mod = LegalizeOps()(Repeat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_repeat_symbolic(): + # fmt: off + @I.ir_module + class Repeat: + @R.function + def main(x: R.Tensor(("a", "b", "c"), 
"float32")): + gv = R.repeat(x, 2, 0) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def repeat(var_rxplaceholder: T.handle, var_T_repeat: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c)) + T_repeat = T.match_buffer(var_T_repeat, (T.int64(2) * a, b, c)) + # with T.block("root"): + for ax0, ax1, ax2 in T.grid(a * T.int64(2), b, c): + with T.block("T_repeat"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]) + T.writes(T_repeat[v_ax0, v_ax1, v_ax2]) + T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2] + + @R.function + def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("2 * a", "b", "c"), dtype="float32"): + a = T.Var("a", "int64") + b = T.Var("b", "int64") + c = T.Var("c", "int64") + gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((2 * a, b, c), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Repeat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tile(): + # fmt: off + @I.ir_module + class Tile: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")): + gv = R.tile(x, (2, 1, 2, 3)) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_tile: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(9)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(9)): + with T.block("T_tile"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax1 % T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)]) + T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3]) + T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 % T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)] + + @R.function + def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((2, 3, 4, 9), dtype="float32"): + gv = R.call_tir(Expected.tile, (x,), out_sinfo=R.Tensor((2, 3, 4, 9), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Tile) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tile_symbolic(): + # fmt: off + @I.ir_module + class Tile: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")): + gv = R.tile(x, (2, 1, 2, 3)) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tile(var_rxplaceholder: T.handle, var_T_tile: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c)) + T_tile = T.match_buffer(var_T_tile, (T.int64(2), a, b * T.int64(2), c * T.int64(3))) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), a, b * T.int64(2), c * T.int64(3)): + with T.block("T_tile"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c]) + T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3]) + T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c] + + @R.function + def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor((2, "a", "b * 2", "c * 3"), dtype="float32"): + a = T.Var("a", "int64") + b = T.Var("b", "int64") + c = T.Var("c", "int64") + gv = R.call_tir(Expected.tile, (x,), out_sinfo=R.Tensor((2, a, b * 2, c * 3), dtype="float32")) + 
return gv + # fmt: on + mod = LegalizeOps()(Tile) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cumsum(): + # fmt: off + @I.ir_module + class Cumsum: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")): + gv = R.cumsum(x, axis=1, dtype="int32") + return gv + + @I.ir_module + class Expected: + @T.prim_func + def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)]) + T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)]) + for fused in T.parallel(T.int64(9)): + out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)]) + for _k in range(T.int64(1)): + out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) % T.int64(3)] + T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)]) + + @R.function + def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"): + cls = Expected + gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Cumsum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cumsum_symbolic(): + # fmt: off + @I.ir_module + class Cumsum: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")): + gv = R.cumsum(x, axis=1, dtype="int32") + return gv + + @I.ir_module + class Expected: + @T.prim_func + def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle): + T.func_attr({"tir.noalias": True}) + a, b, c = T.int64(), T.int64(), T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1) + out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32") + 
with T.block("cumsum_generic"): + T.reads(rxplaceholder[T.int64(0):a, T.int64(0):b, T.int64(0):c]) + T.writes(out_buf[T.int64(0):a, T.int64(0):b, T.int64(0):c]) + for fused in T.parallel(a * c): + out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c]) + for _k in range(b - T.int64(1)): + out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c]) + + @R.function + def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"): + a = T.int64() + b = T.int64() + c = T.int64() + cls = Expected + gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Cumsum) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py new file mode 100644 index 000000000000..e807082e3526 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -0,0 +1,2447 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
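+"""Structural tests for LegalizeOps on Relax neural network operators.
+
+Each test runs relax.transform.LegalizeOps() and compares the result against an
+Expected module in which the op has been lowered to R.call_tir plus a TIR prim func."""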
+ +import pytest +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T, ir as I +import tvm.testing + + +##################### Neural network ##################### + + +def test_conv1d(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"): + gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0)) + for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_with_out_dtype(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((4, 3, 3), "float32")) -> R.Tensor((2, 4, 26), "float16"): + gv: R.Tensor((2, 4, 26), "float16") = R.nn.conv1d(x, w, out_dtype="float16") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), dtype="float16")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = 
rxplaceholder[v_i0, v_i1, v_i2] + for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(3), T.int64(3)): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float16(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + T.Cast("float16", pad_temp[v_nn, v_rc, v_yy + v_ry]) * T.Cast("float16", rxplaceholder_1[v_ff, v_rc, v_ry]) + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_nwc(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 28, 128), "float32"), w: R.Tensor((64, 128, 3), "float32")) -> R.Tensor((2, 26, 64), "float32"): + gv: R.Tensor((2, 26, 64), "float32") = R.nn.conv1d(x, w, data_layout="NWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = T.alloc_buffer((T.int64(2), T.int64(28), T.int64(128))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] + for nn, yy, ff, ry, rc in T.grid(T.int64(2), T.int64(26), T.int64(64), T.int64(3), T.int64(128)): + with T.block("conv1d_nwc"): + v_nn, v_yy, v_ff, v_ry, v_rc = T.axis.remap("SSSRR", [nn, yy, ff, ry, rc]) + T.reads(pad_temp[v_nn, v_yy + v_ry, v_rc], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_nwc[v_nn, v_yy, v_ff]) + with T.init(): + conv1d_nwc[v_nn, v_yy, v_ff] = T.float32(0) + conv1d_nwc[v_nn, v_yy, v_ff] = conv1d_nwc[v_nn, v_yy, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_rc] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor(("n", "c", "w"), "float32"), kernel: R.Tensor(("f", "c", "kw"), "float32")) -> R.Tensor(("n", "f", "w - kw + 1"), "float32"): + n = T.int64() + w = T.int64() + f = T.int64() + kw = T.int64() + gv: R.Tensor((n, f, w - kw + 1), "float32") = R.nn.conv1d(x, kernel) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kw"), dtype="float32")) -> R.Tensor(("n", "f", "w - kw + 1"), dtype="float32"): + n = T.int64() + f = T.int64() + w = T.int64() + kw = T.int64() + c = T.int64() + gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32")) + return gv + + @T.prim_func + def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle): + 
T.func_attr({"tir.noalias": True}) + n, c, w = T.int64(), T.int64(), T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w)) + f, kw = T.int64(), T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw)) + conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1))) + # with T.block("root"): + pad_temp = T.alloc_buffer((n, c, w)) + for i0, i1, i2 in T.grid(n, c, w): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] + for nn, ff, yy, rc, ry in T.grid(n, f, w + T.int64(1) - kw, c, kw): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy + v_ry] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): + gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): + gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 64, 13, 13), dtype="float32")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3), T.int64(3)), "float32"), group_conv2d_nchw: T.Buffer((T.int64(2), T.int64(64), T.int64(13), T.int64(13)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(128), T.int64(30), T.int64(30)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(30), T.int64(30)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(29) and T.int64(1) <= i3_1 and i3_1 < T.int64(29), rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(13), T.int64(16), T.int64(3), T.int64(3)): + with T.block("group_conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(group_conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + group_conv2d_nchw[nn, ff, yy, xx] = T.float32(0) + group_conv2d_nchw[nn, ff, yy, xx] = group_conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * 
T.int64(2) + rx * T.int64(2)] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_with_out_dtype(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"): + gv: R.Tensor((2, 4, 26, 26), "float16") = R.nn.conv2d(x, w, out_dtype="float16") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"): + gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 4, 26, 26), dtype="float16")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float16")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3)): + with T.block("conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + conv2d_nchw[nn, ff, yy, xx] = T.float16(0) + conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + T.Cast("float16", pad_temp[nn, rc, yy + ry, xx + rx]) * T.Cast("float16", rxplaceholder_1[ff, rc, ry, rx]) + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_nhwc(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"): + gv: R.Tensor((2, 26, 26, 64), "float32") = R.nn.conv2d(x, w, data_layout="NHWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"): + gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 26, 26, 64), dtype="float32")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nhwc: T.Buffer((T.int64(2), T.int64(26), T.int64(26), T.int64(64)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(28), T.int64(28), T.int64(128)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(28), T.int64(28), T.int64(128)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, 
i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(26), T.int64(26), T.int64(64), T.int64(3), T.int64(3), T.int64(128)): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = T.float32(0) + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + pad_temp[nn, yy + ry, xx + rx, rc] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.int64() + h = T.int64() + w = T.int64() + f = T.int64() + kh = T.int64() + kw = T.int64() + gv: R.Tensor((n, f, h - kh + 1, w - kw + 1), "float32") = R.nn.conv2d(x, kernel) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.int64() + f = T.int64() + h = T.int64() + kh = T.int64() + w = T.int64() + kw = T.int64() + gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32")) + return gv + + @T.prim_func + def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2d_nchw: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.int64() + f = T.int64() + h = T.int64() + kh = T.int64() + kw = T.int64() + n = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32") + conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32") + pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32") + for i0, i1, i2, i3 in T.grid(n, c, h, w): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(n, f, h + T.int64(1) - kh, w + T.int64(1) - kw, c, kh, kw): + with T.block("conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + conv2d_nchw[nn, ff, yy, xx] = T.float32(0) + conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_transpose(): + # fmt: off + @I.ir_module + class Conv2dTranspose: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((128, 16, 3, 3), "float32")): + gv = R.nn.conv2d_transpose(x, w, strides=(2, 3), padding=(1, 1), dilation=(1, 1), output_padding=(1, 2), groups=8) + return gv + + @I.ir_module + class 
Expected: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3, 3), dtype="float32")) -> R.Tensor((2, 128, 56, 84), dtype="float32"): + gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56, 84), dtype="float32")) + return gv + + @T.prim_func + def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55), T.int64(82))) + data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58), T.int64(86))) + kernel_transform = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3), T.int64(3))) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(55), T.int64(82)): + with T.block("data_dilate"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)]) + T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) + data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)], T.float32(0)) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(58), T.int64(86)): + with T.block("data_pad"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)]) + T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) + data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56) and T.int64(1) <= v_i3 and v_i3 < T.int64(83), data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0)) + for i, o, h, w in T.grid(T.int64(16), T.int64(128), T.int64(3), T.int64(3)): + with T.block("kernel_transform"): + v_i, v_o, v_h, v_w = T.axis.remap("SSSS", [i, o, h, w]) + T.reads(rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w]) + T.writes(kernel_transform[v_i, v_o, v_h, v_w]) + kernel_transform[v_i, v_o, v_h, v_w] = rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w] + for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(84), T.int64(16), T.int64(3), T.int64(3)): + with T.block("compute"): + v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw]) + T.reads(data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw]) + T.writes(compute[v_b, v_c, v_h, v_w]) + with T.init(): + compute[v_b, v_c, v_h, v_w] = T.float32(0) + compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw] + # fmt: on + + mod = LegalizeOps()(Conv2dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_transpose_with_out_dtype(): + # fmt: off + @tvm.script.ir_module + class Conv2dTranspose: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3, 3), "float32")): + gv = R.nn.conv2d_transpose(x, w, out_dtype="float16") + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: 
R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 30, 30), dtype="float16"): + gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 30, 30), dtype="float16")) + return gv + + @T.prim_func + def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(30), T.int64(30)), "float16")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + data_dilate = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + data_pad = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(32), T.int64(32))) + kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3))) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("data_dilate"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3]) + T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) + data_dilate[v_i0, v_i1, v_i2, v_i3] = rxplaceholder[v_i0, v_i1, v_i2, v_i3] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(32)): + with T.block("data_pad"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)]) + T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) + data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(30) and T.int64(2) <= v_i3 and v_i3 < T.int64(30), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0)) + for o, i, h, w in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3)): + with T.block("kernel_transform"): + v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w]) + T.reads(rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]) + T.writes(kernel_transform[v_o, v_i, v_h, v_w]) + kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w] + for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(30), T.int64(30), T.int64(3), T.int64(3), T.int64(3)): + with T.block("compute"): + v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw]) + T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw]) + T.writes(compute[v_b, v_c, v_h, v_w]) + with T.init(): + compute[v_b, v_c, v_h, v_w] = T.float16(0) + compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + T.Cast("float16", data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw]) * T.Cast("float16", kernel_transform[v_c, v_dc, v_dh, v_dw]) + # fmt: on + + mod = LegalizeOps()(Conv2dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_transpose_symbolic(): + # fmt: off + @tvm.script.ir_module + class Conv2dTranspose: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")): + gv = R.nn.conv2d_transpose(x, kernel, strides=(3, 3)) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), dtype="float32")) -> R.Tensor(("n", "c", "h * 3 + kh - 3", "w * 3 + kw - 3"), dtype="float32"): + n = T.Var("n", "int64") + c = T.Var("c", "int64") + h = T.Var("h", "int64") + kh = T.Var("kh", "int64") + w = T.Var("w", "int64") + kw = T.Var("kw", 
"int64") + f = T.Var("f", "int64") + gv = R.call_tir(Expected.conv2d_transpose, (x, kernel), out_sinfo=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32")) + return gv + + @T.prim_func + def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.var("int64") + c = T.var("int64") + h = T.var("int64") + w = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, h, w)) + f = T.var("int64") + kh = T.var("int64") + kw = T.var("int64") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kh, kw)) + compute = T.match_buffer(var_compute, (n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3))) + # with T.block("root"): + data_dilate = T.alloc_buffer((n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2))) + data_pad = T.alloc_buffer((n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4))) + kernel_transform = T.alloc_buffer((c, c, kh, kw)) + for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)): + with T.block("data_dilate"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)]) + T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) + data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(3) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)], T.float32(0)) + for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)): + with T.block("data_pad"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw]) + T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) + data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh - T.int64(1) <= v_i2 and v_i2 < h * T.int64(3) + kh - T.int64(3) and kw - T.int64(1) <= v_i3 and v_i3 < w * T.int64(3) + kw - T.int64(3), data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], T.float32(0)) + for o, i, h_1, w_1 in T.grid(c, c, kh, kw): + with T.block("kernel_transform"): + v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1]) + T.reads(rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)]) + T.writes(kernel_transform[v_o, v_i, v_h, v_w]) + kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)] + for b, c_1, h_1, w_1, dc, dh, dw in T.grid(n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3), c, kh, kw): + with T.block("compute"): + v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c_1, h_1, w_1, dc, dh, dw]) + T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw]) + T.writes(compute[v_b, v_c, v_h, v_w]) + with T.init(): + compute[v_b, v_c, v_h, v_w] = T.float32(0) + compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dh, v_dw] + # fmt: on + + mod = LegalizeOps()(Conv2dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"): + gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.max_pool2d(x, 
pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], layout="NHWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"): + gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 56, 56, 6), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(4), T.int64(114), T.int64(114), T.int64(6)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3]) + T.writes(pad_temp[ax0, ax1, ax2, ax3]) + pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax1 and ax1 < T.int64(113) and T.int64(1) <= ax2 and ax2 < T.int64(113), rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3], T.float32(-3.4028234663852886e+38), dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) + T.writes(pool_max[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d_NCHW16c(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"): + gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], layout="NCHW16c") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"): + gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 4, 110, 110, 16), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) + T.writes(pool_max[ax0, ax1, ax2, ax3, ax4]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3, ax4] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3, ax4] = T.max(pool_max[ax0, ax1, ax2, ax3, ax4], rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + 
+def test_max_pool2d_ceil_mode(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), "float32"): + gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"): + gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 6, 38, 38), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(4), T.int64(6), T.int64(116), T.int64(116)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)]) + T.writes(pad_temp[ax0, ax1, ax2, ax3]) + pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax2 and ax2 < T.int64(113) and T.int64(1) <= ax3 and ax3 < T.int64(113), rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38), dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) + T.writes(pool_max[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI pooling casts every shape value to i32.") +def test_max_pool2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + kh = T.int64() + kw = T.int64() + gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.max_pool2d(x, pool_size=[kh, kw]) + return gv + + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_avg_pool2d(): + # fmt: off + @tvm.script.ir_module + class AvgPool2D: + @R.function + def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"): + gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], layout="NHWC") + return gv + + @I.ir_module + class Expected: + @T.prim_func + def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = 
T.alloc_buffer((T.int64(4), T.int64(114), T.int64(114), T.int64(6))) + pool_sum = T.alloc_buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6))) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): + with T.block("pad_temp"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3]) + T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3]) + pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax1 and v_ax1 < T.int64(113) and T.int64(1) <= v_ax2 and v_ax2 < T.int64(113), rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3], T.float32(0)) + for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)): + with T.block("pool_sum"): + v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) + T.reads(pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3]) + T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + with T.init(): + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6)): + with T.block("pool_avg"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - v_ax1 * T.int64(2)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax2 * T.int64(2)) + T.int64(2))) + + @R.function + def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) -> R.Tensor((4, 56, 56, 6), dtype="float32"): + gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 56, 56, 6), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(AvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_avg_pool2d_NCHW16c(): + # fmt: off + @tvm.script.ir_module + class AvgPool2D: + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"): + gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], layout="NCHW16c") + return gv + + @I.ir_module + class Expected: + @T.prim_func + def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pool_sum = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16))) + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): + with T.block("pool_sum"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap("SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4]) + T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + with T.init(): + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(0) + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, 
v_ax4] + rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)): + with T.block("pool_avg"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", (T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1)) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1))) + + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> R.Tensor((4, 4, 110, 110, 16), dtype="float32"): + gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 4, 110, 110, 16), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(AvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_avg_pool2d_ceil_mode(): + # fmt: off + @tvm.script.ir_module + class AvgPool2D: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), "float32"): + gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(116), T.int64(116))) + pool_sum = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38))) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): + with T.block("pad_temp"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)]) + T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3]) + pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax2 and v_ax2 < T.int64(113) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(113), rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)], T.float32(0)) + for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)): + with T.block("pool_sum"): + v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) + T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1]) + T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + with T.init(): + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) + pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38)): + with T.block("pool_avg"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - v_ax2 * T.int64(3)) + T.int64(2)) * 
(T.min(T.int64(1), T.int64(112) - v_ax3 * T.int64(3)) + T.int64(2))) + + @R.function + def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"): + gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 6, 38, 38), dtype="float32")) + return gv + + # fmt: on + + mod = LegalizeOps()(AvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI pooling casts every shape value to i32.") +def test_avg_pool2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class AvgPool2D: + @R.function + def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + kh = T.int64() + kw = T.int64() + gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.avg_pool2d(x, pool_size=[kh, kw]) + return gv + + # fmt: on + + mod = LegalizeOps()(AvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_adaptive_avg_pool2d(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"): + gv: R.Tensor((2, 4, 1, 1, 16), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=[1, 1], layout="NCHW16c") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"): + gv = R.call_tir(Expected.adaptive_avg_pool2d, (x,), R.Tensor((2, 4, 1, 1, 16), dtype="float32")) + return gv + + @T.prim_func + def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(7), T.int64(7), T.int64(16)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)], dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16), T.int64(7), T.int64(7)): + with T.block("adaptive_pool_sum"): + ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4]) + T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) + with T.init(): + adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = T.float32(0) + adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] + rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4] + for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)): + with T.block("adaptive_pool_avg"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) + T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4]) + T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] * T.float32(0.020408163265306121) + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_adaptive_avg_pool2d_without_output_size(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"): + gv: R.Tensor((2, 16, 7, 7), "float32") 
= R.nn.adaptive_avg_pool2d(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"): + gv = R.call_tir(Expected.adaptive_avg_pool2d, (x,), R.Tensor((2, 16, 7, 7), dtype="float32")) + return gv + + @T.prim_func + def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32")): + T.func_attr({"tir.noalias": True}) + adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(16), T.int64(7), T.int64(7)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7), T.int64(1), T.int64(1)): + with T.block("adaptive_pool_sum"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1]) + T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3]) + with T.init(): + adaptive_pool_sum[ax0, ax1, ax2, ax3] = T.float32(0) + adaptive_pool_sum[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7)): + with T.block("adaptive_pool_avg"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3]) + T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI pooling casts every shape value to i32.") +def test_adaptive_avg_pool2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "oh", "ow"), "float32"): + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() + gv: R.Tensor((n, c, oh, ow), "float32") = R.nn.adaptive_avg_pool2d(x, (oh, ow)) + return gv + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_relu(): + # fmt: off + @tvm.script.ir_module + class Relu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.relu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.relu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Relu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_relu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Relu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = 
R.nn.relu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.relu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def relu(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Relu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_gelu(): + # fmt: off + @tvm.script.ir_module + class Gelu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.gelu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def gelu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + T_multiply_1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_divide = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_multiply_1[ax0, ax1]) + T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[ax0, ax1]) + T.writes(T_multiply_2[ax0, ax1]) + T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_2[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], T_divide[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_divide[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Gelu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_gelu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Gelu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.nn.gelu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function 
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.gelu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") + T_multiply_1 = T.alloc_buffer([m, n], dtype="float32") + compute = T.alloc_buffer([m, n], dtype="float32") + T_multiply_2 = T.alloc_buffer([m, n], dtype="float32") + T_add = T.alloc_buffer([m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_multiply_1[ax0, ax1]) + T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_multiply_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[ax0, ax1]) + T.writes(T_multiply_2[ax0, ax1]) + T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) + for i0, i1 in T.grid(m, n): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_2[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_multiply_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], T_add[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_add[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Gelu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_silu(): + # fmt: off + @tvm.script.ir_module + class Silu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.silu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.silu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def silu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Silu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_silu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Silu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, 
n), "float32") = R.nn.silu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.silu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") + compute = T.alloc_buffer([m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + for i0, i1 in T.grid(m, n): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Silu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_softmax(): + # fmt: off + @tvm.script.ir_module + class Softmax: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"): + gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.softmax(x, axis=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"): + gv = R.call_tir(Expected.softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32")) + return gv + + @T.prim_func + def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), T_softmax_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32")): + T.func_attr({"tir.noalias": True}) + T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + T_softmax_exp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32") + T_softmax_expsum = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) + T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("T_softmax_exp"): + i0_2, i1_2, i2_2, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_2, i1_2, i2_2, i3_1], T_softmax_maxelem[i0_2, i1_2, i3_1]) + T.writes(T_softmax_exp[i0_2, i1_2, i2_2, i3_1]) + T_softmax_exp[i0_2, i1_2, i2_2, i3_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_2, i3_1] - T_softmax_maxelem[i0_2, i1_2, i3_1], dtype="float32") + for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_expsum"): + i0_4, i1_4, i2_4, k = T.axis.remap("SSSR", [i0_3, i1_3, i2_3, i3]) + T.reads(T_softmax_exp[i0_4, i1_4, k, i2_4]) + T.writes(T_softmax_expsum[i0_4, i1_4, i2_4]) + with T.init(): + T_softmax_expsum[i0_4, 
i1_4, i2_4] = T.float32(0) + T_softmax_expsum[i0_4, i1_4, i2_4] = T_softmax_expsum[i0_4, i1_4, i2_4] + T_softmax_exp[i0_4, i1_4, k, i2_4] + for i0_5, i1_5, i2_5, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("T_softmax_norm"): + i0_6, i1_6, i2_6, i3_2 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3]) + T.reads(T_softmax_exp[i0_6, i1_6, i2_6, i3_2], T_softmax_expsum[i0_6, i1_6, i3_2]) + T.writes(T_softmax_norm[i0_6, i1_6, i2_6, i3_2]) + T.block_attr({"axis":2}) + T_softmax_norm[i0_6, i1_6, i2_6, i3_2] = T_softmax_exp[i0_6, i1_6, i2_6, i3_2] / T_softmax_expsum[i0_6, i1_6, i3_2] + # fmt: on + + mod = LegalizeOps()(Softmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_softmax_symbolic(): + # fmt: off + @tvm.script.ir_module + class Softmax: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv: R.Tensor((a, b, c), "float32") = R.nn.softmax(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.softmax, (x,), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [a, b, c], dtype="float32") + T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") + T_softmax_exp = T.alloc_buffer([a, b, c], dtype="float32") + T_softmax_expsum = T.alloc_buffer([a, b], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, k]) + T.writes(T_softmax_maxelem[i0_1, i1_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[i0_1, i1_1] = T.max(T_softmax_maxelem[i0_1, i1_1], rxplaceholder[i0_1, i1_1, k]) + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_softmax_exp"): + i0_2, i1_2, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_2, i1_2, i2_1], T_softmax_maxelem[i0_2, i1_2]) + T.writes(T_softmax_exp[i0_2, i1_2, i2_1]) + T_softmax_exp[i0_2, i1_2, i2_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_1] - T_softmax_maxelem[i0_2, i1_2], dtype="float32") + for i0_3, i1_3, i2 in T.grid(a, b, c): + with T.block("T_softmax_expsum"): + i0_4, i1_4, k = T.axis.remap("SSR", [i0_3, i1_3, i2]) + T.reads(T_softmax_exp[i0_4, i1_4, k]) + T.writes(T_softmax_expsum[i0_4, i1_4]) + with T.init(): + T_softmax_expsum[i0_4, i1_4] = T.float32(0) + T_softmax_expsum[i0_4, i1_4] = T_softmax_expsum[i0_4, i1_4] + T_softmax_exp[i0_4, i1_4, k] + for i0_5, i1_5, i2 in T.grid(a, b, c): + with T.block("T_softmax_norm"): + i0_6, i1_6, i2_2 = T.axis.remap("SSS", [i0_5, i1_5, i2]) + T.reads(T_softmax_exp[i0_6, i1_6, i2_2], T_softmax_expsum[i0_6, i1_6]) + T.writes(T_softmax_norm[i0_6, i1_6, i2_2]) + T.block_attr({"axis":2}) + T_softmax_norm[i0_6, i1_6, i2_2] = T_softmax_exp[i0_6, i1_6, i2_2] / T_softmax_expsum[i0_6, i1_6] + # fmt: on + + mod = LegalizeOps()(Softmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log_softmax(): + # fmt: off + @tvm.script.ir_module + class LogSoftmax: + @R.function + def main(x: 
R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.log_softmax(x, axis=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), dtype="float32")) -> R.Tensor((2, 3, 16, 32), dtype="float32"): + gv = R.call_tir(Expected.log_softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32")) + return gv + + @T.prim_func + def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"),): + T.func_attr({"tir.noalias": True}) + T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + compute_1 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) + T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("compute"): + i0_2, i1_2, i2_2, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_2, i1_2, k, i2_2], T_softmax_maxelem[i0_2, i1_2, i2_2]) + T.writes(compute_1[i0_2, i1_2, i2_2]) + with T.init(): + compute_1[i0_2, i1_2, i2_2] = T.float32(0) + compute_1[i0_2, i1_2, i2_2] = compute_1[i0_2, i1_2, i2_2] + T.exp(rxplaceholder[i0_2, i1_2, k, i2_2] - T_softmax_maxelem[i0_2, i1_2, i2_2], dtype="float32") + for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("compute_1"): + i0_4, i1_4, i2_4, i3_1 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3]) + T.reads(rxplaceholder[i0_4, i1_4, i2_4, i3_1], T_softmax_maxelem[i0_4, i1_4, i3_1], compute_1[i0_4, i1_4, i3_1]) + T.writes(compute[i0_4, i1_4, i2_4, i3_1]) + T.block_attr({"axis": 2}) + compute[i0_4, i1_4, i2_4, i3_1] = (rxplaceholder[i0_4, i1_4, i2_4, i3_1] - T_softmax_maxelem[i0_4, i1_4, i3_1] - T.log(compute_1[i0_4, i1_4, i3_1], dtype="float32")) + # fmt: on + + mod = LegalizeOps()(LogSoftmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log_softmax_symbolic(): + # fmt: off + @tvm.script.ir_module + class LogSoftmax: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv: R.Tensor((a, b, c), "float32") = R.nn.log_softmax(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="float32"): + a = T.int64() + b = T.int64() + c = T.int64() + # block 0 + gv = R.call_tir(Expected.log_softmax, (x,), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + compute = T.match_buffer(var_compute, [a, b, c], dtype="float32") + T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") + compute_1 = T.alloc_buffer([a, b], 
dtype="float32") + for i0, i1, k in T.grid(a, b, c): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], rxplaceholder[v_i0, v_i1, v_k]) + for i0, i1, k in T.grid(a, b, c): + with T.block("compute"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_k], T_softmax_maxelem[v_i0, v_i1]) + T.writes(compute_1[v_i0, v_i1]) + with T.init(): + compute_1[v_i0, v_i1] = T.float32(0) + compute_1[v_i0, v_i1] = compute_1[v_i0, v_i1] + T.exp(rxplaceholder[v_i0, v_i1, v_k] - T_softmax_maxelem[v_i0, v_i1], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1], compute_1[v_i0, v_i1],) + T.writes(compute[v_i0, v_i1, v_i2]) + T.block_attr({"axis": 2}) + compute[v_i0, v_i1, v_i2] = (rxplaceholder[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1] - T.log(compute_1[v_i0, v_i1], dtype="float32")) + # fmt: on + + mod = LegalizeOps()(LogSoftmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(rxplaceholder: T.Buffer(T.int64(3), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), T_multiply: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply"): + ax0 = T.axis.spatial(T.int64(3), i0) + T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder[ax0] * rxplaceholder_1[ax0] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_red"): + k0 = T.axis.reduce(T.int64(3), i0) + T.reads(T_multiply_1[k0]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[k0] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_multiply[()]) + T_multiply[()] = T_multiply_red[()] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits_batch(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32"), y: 
R.Tensor((2, 3), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_multiply_1 = T.alloc_buffer([], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * rxplaceholder_1[ax0, ax1] + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_red"): + k0, k1 = T.axis.remap("RR", [i0, i1]) + T.reads(T_multiply[k0, k1]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_multiply_1[()]) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_1[()]) + T.writes(T_divide[()]) + T_divide[()] = T_multiply_1[()] * T.float32(0.5) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits_batch_symbolic(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor(("n", "m"), "float32"), y: R.Tensor(("n", "m"), "float32")) -> R.Tensor(None, "float32", ndim=2): + n = T.int64() + m = T.int64() + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + T_multiply = T.alloc_buffer([n, m], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_multiply_1 = T.alloc_buffer([], dtype="float32") + for ax0, ax1 in T.grid(n, m): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * rxplaceholder_1[v_ax0, v_ax1] + for k0, k1 in T.grid(n, m): + with T.block("T_multiply_red"): + v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) + T.reads(T_multiply[v_k0, v_k1]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + 
T.reads(T_multiply_red[()]) + T.writes(T_multiply_1[()]) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_1[()]) + T.writes(T_divide[()]) + T_divide[()] = T_multiply_1[()] / T.Cast("float32", n) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_norm(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): + gv: R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): + gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")]) + return gv + + @T.prim_func + def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), rxplaceholder_2: T.Buffer(T.int64(3), "float32"), rxplaceholder_3: T.Buffer(T.int64(3), "float32"), rxplaceholder_4: T.Buffer(T.int64(3), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_multiply: T.Buffer(T.int64(3), "float32"), T_multiply_1: T.Buffer(T.int64(3), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_add_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + compute = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_divide = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_2 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_3 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_3[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape[ax0, ax1, ax2, ax3]) + T_reshape[ax0, ax1, ax2, ax3] = rxplaceholder_3[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_reshape[T.int64(0), ax1, T.int64(0), T.int64(0)]) 
+ T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_reshape[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_4[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_1[ax0, ax1, ax2, ax3]) + T_reshape_1[ax0, ax1, ax2, ax3] = rxplaceholder_4[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_reshape_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_reshape_1[ax0, ax1, ax2, ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[i0_1, i1_1, i2_1, i3_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1]) + compute[i0_1, i1_1, i2_1, i3_1] = T.sqrt(T_add_1[i0_1, i1_1, i2_1, i3_1], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], compute[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] / compute[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_2[ax0, ax1, ax2, ax3]) + T_reshape_2[ax0, ax1, ax2, ax3] = rxplaceholder_1[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_divide[ax0, ax1, ax2, ax3], T_reshape_2[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3]) + T_multiply_2[ax0, ax1, ax2, ax3] = T_divide[ax0, ax1, ax2, ax3] * T_reshape_2[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_2[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_3[ax0, ax1, ax2, ax3]) + T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_2[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3], T_reshape_3[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_multiply_2[ax0, ax1, ax2, ax3] + T_reshape_3[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_1"): + ax0 = T.axis.spatial(T.int64(3), i0) + T.reads(rxplaceholder_3[ax0]) + T.writes(T_multiply[ax0]) + T_multiply[ax0] = rxplaceholder_3[ax0] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_2"): + ax0 = T.axis.spatial(T.int64(3), i0) + 
T.reads(rxplaceholder_4[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder_4[ax0] + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")]) + return gv + + @T.prim_func + def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_multiply: T.handle, var_T_multiply_1: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.int64() + h = T.int64() + n = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, h, w, c], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [c], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [c], dtype="float32") + rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, [c], dtype="float32") + rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, [c], dtype="float32") + T_add = T.match_buffer(var_T_add, [n, h, w, c], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [c], dtype="float32") + T_multiply_1 = T.match_buffer(var_T_multiply_1, [c], dtype="float32") + T_reshape = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_subtract = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_1 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_add_1 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + compute = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_divide = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_2 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_multiply_2 = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_3 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_3[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape[ax0, ax1, ax2, ax3]) + T_reshape[ax0, ax1, ax2, ax3] = rxplaceholder_3[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in 
T.grid(n, h, w, c): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_reshape[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_reshape[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_4[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_1[ax0, ax1, ax2, ax3]) + T_reshape_1[ax0, ax1, ax2, ax3] = rxplaceholder_4[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_reshape_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_reshape_1[ax0, ax1, ax2, ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[i0_1, i1_1, i2_1, i3_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1]) + compute[i0_1, i1_1, i2_1, i3_1] = T.sqrt(T_add_1[i0_1, i1_1, i2_1, i3_1], dtype="float32") + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], compute[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] / compute[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_2[ax0, ax1, ax2, ax3]) + T_reshape_2[ax0, ax1, ax2, ax3] = rxplaceholder_1[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_divide[ax0, ax1, ax2, ax3], T_reshape_2[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3]) + T_multiply_2[ax0, ax1, ax2, ax3] = T_divide[ax0, ax1, ax2, ax3] * T_reshape_2[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_2[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_3[ax0, ax1, ax2, ax3]) + T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_2[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3], T_reshape_3[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_multiply_2[ax0, ax1, ax2, ax3] + T_reshape_3[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0 in T.serial(c): + with T.block("T_multiply_1"): + ax0 = T.axis.spatial(c, i0) + T.reads(rxplaceholder_3[ax0]) + T.writes(T_multiply[ax0]) + T_multiply[ax0] = rxplaceholder_3[ax0] + for i0 in T.serial(c): + with T.block("T_multiply_2"): + ax0 = 
T.axis.spatial(c, i0) + T.reads(rxplaceholder_4[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder_4[ax0] + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_layer_norm(): + # fmt: off + @tvm.script.ir_module + class LayerNorm: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"): + gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), R.Tensor((2, 3, 4, 5), dtype="float32")) + return gv + + @T.prim_func + def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] * rxplaceholder[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3] + # fmt: on + mod = LegalizeOps()(LayerNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_layer_norm_fp16(): + # fmt: off + @tvm.script.ir_module + class LayerNorm: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), gamma: R.Tensor((4, 5), "float16"), beta: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"): + gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + @I.ir_module + class 
Expected: + @T.prim_func + def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": True}) + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), T.int64(5)), "float16") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), T.int64(5)), "float16") + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") + with T.block("root"): + T.reads() + T.writes() + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(3))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(3))) + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for k2 in range(T.int64(4)): + for k3 in range(T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_k2 = T.axis.reduce(T.int64(4), k2) + v_k3 = T.axis.reduce(T.int64(5), k3) + T.reads(rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(4)): + for ax3 in range(T.int64(5)): + with T.block("T_layer_norm"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(4), ax2) + v_ax3 = T.axis.spatial(T.int64(5), ax3) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3], rxplaceholder_2[v_ax2, v_ax3]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3]) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) + T.float32(1.0000000000000001e-05))) * rxplaceholder_1[v_ax2, v_ax3] + rxplaceholder_2[v_ax2, v_ax3] + + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma: R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"): + gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16")) + return gv + # fmt: on + mod = LegalizeOps()(LayerNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_layer_norm_symbolic(): + # fmt: off + 
@tvm.script.ir_module + class LayerNorm: + @R.function + def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): + n = T.int64() + s = T.int64() + f = T.int64() + gv: R.Tensor((n, s, f), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): + n = T.int64() + s = T.int64() + f = T.int64() + gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), R.Tensor((n, s, f), dtype="float32")) + return gv + + @T.prim_func + def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": True}) + f = T.int64() + n = T.int64() + s = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], dtype="float32") + T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f], dtype="float32") + rxplaceholder_red_temp_v0 = T.alloc_buffer([n], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([n], dtype="float32") + for i0, i1, i2 in T.grid(n, s, f): + with T.block("rxplaceholder_red_temp"): + ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, k1, k2]) + T.writes(rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0]) + with T.init(): + rxplaceholder_red_temp_v0[ax0] = T.float32(0) + rxplaceholder_red_temp_v1[ax0] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] * rxplaceholder[ax0, k1, k2] + rxplaceholder_red_temp_v0[ax0] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2 in T.grid(n, s, f): + with T.block("T_layer_norm"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2]) + T.writes(T_layer_norm[ax0, ax1, ax2]) + T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1, ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) * (rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] + rxplaceholder_2[ax1, ax2] + # fmt: on + mod = LegalizeOps()(LayerNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), beta: R.Tensor((4,), "float32")) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(rxplaceholder: 
T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] + for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, 
v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)] + + @R.function + def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"): + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm_fp16(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), "float16"), gamma: R.Tensor((4,), "float16"), beta: R.Tensor((4,), "float16")) -> R.Tensor((2, 4, 4, 5), "float16"): + gv: R.Tensor((2, 4, 4, 5), "float16") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma: R.Tensor((4,), dtype="float16"), beta: R.Tensor((4,), dtype="float16")) -> R.Tensor((2, 4, 4, 5), dtype="float16"): + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float16")) + return gv + + @T.prim_func + def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(4),), "float16"), rxplaceholder_2: T.Buffer((T.int64(4),), "float16"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16") + T_cast = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16") + T_reshape_3 = 
T.alloc_buffer((T.int64(2), T.int64(2)), "float16") + T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16") + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float32", T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + 
T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)] + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), gamma: R.Tensor(("4 * c",), "float32"), beta: R.Tensor(("4 * c",), "float32")) -> R.Tensor(("n", "4 * c", "h", "w"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv: R.Tensor((n, 4 * c, h, w), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64): + T.func_attr({"tir.noalias": True}) + n = T.int64() + h = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, T.int64(4) * c, h, w)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,)) + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,)) + T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w)) + # with T.block("root"): + T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4))) + T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + 
v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w] + for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * 
T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w] + + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_attention(): + # fmt: off + @tvm.script.ir_module + class Attention: + @R.function + def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4, 32, 16, 8), "float32")): + scale = T.FloatImm("float32", 0.1) + gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v, bias, scale) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def attention_bias(rxplaceholder: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), rxplaceholder_3: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) + T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(8))) + T_reshape_1 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(8))) + T_batch_matmul_NT = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_multiply = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) + T_add = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) + T_reshape_3 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_softmax_maxelem = T.alloc_buffer((T.int64(128), T.int64(16))) + T_softmax_exp = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_softmax_expsum = T.alloc_buffer((T.int64(128), T.int64(16))) + T_softmax_norm = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) + T_transpose_3 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(16))) + T_reshape_4 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(16))) + T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16), 
T.int64(16))) + T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(16))) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)): + with T.block("T_transpose_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)): + with T.block("T_reshape_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2]) + T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)] + for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)): + with T.block("T_batch_matmul_NT"): + v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k]) + T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k]) + T.writes(T_batch_matmul_NT[v_b, v_i, v_j]) + T.block_attr({"layout_free_placeholders": [T_reshape_1]}) + with T.init(): + T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0) + T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] * T.float32(0.10000000000000001) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): + with T.block("T_reshape_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % 
T.int64(16), v_ax3 % T.int64(8)]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2]) + T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)] + for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(T_reshape_3[v_i0, v_i1, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_reshape_3[v_i0, v_i1, v_k]) + for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_reshape_3[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) + T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_reshape_3[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) + for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1]) + with T.init(): + T_softmax_expsum[v_i0, v_i1] = T.float32(0) + T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] + for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) + T.block_attr({"axis": 2}) + T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)): + with T.block("T_transpose_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3] + for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)): + with T.block("T_reshape_4"): + 
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2]) + T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)] + for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)): + with T.block("T_batch_matmul_NN"): + v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k]) + T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j]) + T.writes(T_batch_matmul_NN[v_b, v_i, v_j]) + T.block_attr({"layout_free_placeholders": [T_reshape_4]}) + with T.init(): + T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0) + T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)): + with T.block("T_reshape_5"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)]) + T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16), T.int64(32), T.int64(16)): + with T.block("T_transpose_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3] + + @R.function + def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"), bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16), dtype="float32"): + gv = R.call_tir(Expected.attention_bias, (q, k, v, bias), out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Attention) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py new file mode 100644 index 000000000000..682abf2d5777 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -0,0 +1,990 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Search ##################### + + +def test_where(): + # fmt: off + @tvm.script.ir_module + class Where: + @R.function + def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) -> R.Tensor((3, 2, 3), "float32"): + gv: R.Tensor((3, 2, 3), "float32") = R.where(condition, x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) -> R.Tensor((3, 2, 3), "float32"): + gv = R.call_tir(Expected.where, (condition, x, y), R.Tensor((3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def where(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(1)), "bool"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(1)), "float32"), T_where: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_where"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + T.writes(T_where[ax0, ax1, ax2]) + T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Where) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_where_symbolic(): + # fmt: off + @tvm.script.ir_module + class Where: + @R.function + def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv: R.Tensor((a, b, c), "float32") = R.where(condition, x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.int64() + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.where, (condition, x, y), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_where: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, T.int64(1)], dtype="bool") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [b, c], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [b, T.int64(1)], dtype="float32") + T_where = T.match_buffer(var_T_where, [a, b, c], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_where"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, 
T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + T.writes(T_where[ax0, ax1, ax2]) + T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Where) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_argmax(): + # fmt: off + @tvm.script.ir_module + class Argmax: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 4, 5), "int64"): + gv: R.Tensor((2, 4, 5), "int64") = R.argmax(x, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 4, 5), dtype="int64"): + gv = R.call_tir(Expected.argmax, (x,), out_sinfo=R.Tensor((2, 4, 5), dtype="int64")) + return gv + + @T.prim_func + def argmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)), "int64") + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, k1 in T.grid(T.int64(2), T.int64(4), T.int64(5), T.int64(3)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_ax2, v_k1 = T.axis.remap("SSSR", [ax0, ax1, ax2, k1]) + T.reads(rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.int64(-1) + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.min_value("float32") + v_rxplaceholder_red_temp_v0: T.int64 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] > rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2] or rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] == rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2] and rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2] < v_k1, rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2], v_k1) + v_rxplaceholder_red_temp_v1: T.float32 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] > rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2]) + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_rxplaceholder_red_temp_v1 + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2]) + T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) + rxplaceholder_red[v_ax0, v_ax1, v_ax2] = rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2] + # fmt: on + + mod = LegalizeOps()(Argmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_argmax_symbolic(): + # fmt: off + @tvm.script.ir_module + class Argmax: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, "c", "d"), "int64"): + a = T.int64() + c = T.int64() + d = T.int64() + gv: R.Tensor((a, 1, c, d), "int64") = R.argmax(x, axis=1, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("a", 1, "c", "d"), dtype="int64"): + a = T.int64() + c = T.int64() + d = T.int64() + gv = R.call_tir(Expected.argmax, (x,), 
out_sinfo=R.Tensor((a, 1, c, d), dtype="int64")) + return gv + + @T.prim_func + def argmax(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d)) + rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (a, T.int64(1), c, d), "int64") + # with T.block("root"): + rxplaceholder_red_temp_v0 = T.alloc_buffer((a, T.int64(1), c, d), "int64") + rxplaceholder_red_temp_v1 = T.alloc_buffer((a, T.int64(1), c, d)) + for ax0, ax1, ax2, ax3, k1 in T.grid(a, T.int64(1), c, d, b): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_ax2, v_ax3, v_k1 = T.axis.remap("SSSSR", [ax0, ax1, ax2, ax3, k1]) + T.reads(rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = T.int64(-1) + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = T.min_value("float32") + v_rxplaceholder_red_temp_v0: T.int64 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] > rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3] or rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] == rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3] and rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_k1, rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], v_k1) + v_rxplaceholder_red_temp_v1: T.float32 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] > rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3]) + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v1 + for ax0, ax1, ax2, ax3 in T.grid(a, T.int64(1), c, d): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) + rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] + # fmt: on + + mod = LegalizeOps()(Argmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_argmin(): + # fmt: off + @tvm.script.ir_module + class Argmin: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "int64"): + gv: R.Tensor((), "int64") = R.argmin(x) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def argmin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "int64")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red_temp_v0 = T.alloc_buffer((), "int64") + rxplaceholder_red_temp_v1 = T.alloc_buffer(()) + for k0, k1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2, k3]) + T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) + T.writes(rxplaceholder_red_temp_v0[()], rxplaceholder_red_temp_v1[()]) + with T.init(): + rxplaceholder_red_temp_v0[()] = T.int64(-1) + rxplaceholder_red_temp_v1[()] = T.max_value("float32") + v_rxplaceholder_red_temp_v0: T.int64 = T.Select(rxplaceholder_red_temp_v1[()] < rxplaceholder[v_k0, v_k1, v_k2, v_k3] or rxplaceholder_red_temp_v1[()] == 
rxplaceholder[v_k0, v_k1, v_k2, v_k3] and rxplaceholder_red_temp_v0[()] < v_k0 * T.int64(60) + v_k1 * T.int64(20) + v_k2 * T.int64(5) + v_k3, rxplaceholder_red_temp_v0[()], v_k0 * T.int64(60) + v_k1 * T.int64(20) + v_k2 * T.int64(5) + v_k3) + v_rxplaceholder_red_temp_v1: T.float32 = T.Select(rxplaceholder_red_temp_v1[()] < rxplaceholder[v_k0, v_k1, v_k2, v_k3], rxplaceholder_red_temp_v1[()], rxplaceholder[v_k0, v_k1, v_k2, v_k3]) + rxplaceholder_red_temp_v0[()] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[()] = v_rxplaceholder_red_temp_v1 + with T.block("rxplaceholder_red"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(rxplaceholder_red_temp_v0[()]) + T.writes(rxplaceholder_red[()]) + rxplaceholder_red[()] = rxplaceholder_red_temp_v0[()] + + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((), dtype="int64"): + gv = R.call_tir(Expected.argmin, (x,), out_sinfo=R.Tensor((), dtype="int64")) + return gv + # fmt: on + + mod = LegalizeOps()(Argmin) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_argmin_symbolic(): + # fmt: off + @tvm.script.ir_module + class Argmin: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), "int64"): + gv: R.Tensor((1, 1, 1, 1), "int64") = R.argmin(x, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def argmin(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "int64")): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d)) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "int64") + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) + for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 = T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3]) + T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = T.int64(-1) + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = T.max_value("float32") + v_rxplaceholder_red_temp_v0: T.int64 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] < rxplaceholder[v_k0, v_k1, v_k2, v_k3] or rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] == rxplaceholder[v_k0, v_k1, v_k2, v_k3] and rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] < ((v_k0 * b + v_k1) * c + v_k2) * d + v_k3, rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], ((v_k0 * b + v_k1) * c + v_k2) * d + v_k3) + v_rxplaceholder_red_temp_v1: T.float32 = T.Select(rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] < rxplaceholder[v_k0, v_k1, v_k2, v_k3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder[v_k0, v_k1, v_k2, v_k3]) + rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v1 + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + 
T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) + rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] + + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor((1, 1, 1, 1), dtype="int64"): + gv = R.call_tir(Expected.argmin, (x,), out_sinfo=R.Tensor((1, 1, 1, 1), dtype="int64")) + return gv + # fmt: on + + mod = LegalizeOps()(Argmin) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Statistical ##################### + + +def test_max(): + # fmt: off + @tvm.script.ir_module + class Max: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): + gv: R.Tensor((2, 5), "float32") = R.max(x, axis=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): + gv = R.call_tir(Expected.max, (x,), R.Tensor((2, 5), dtype="float32")) + return gv + + @T.prim_func + def max(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(5), T.int64(3), T.int64(4)): + with T.block("rxplaceholder_red"): + ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, k1, k2, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.min_value("float32") + rxplaceholder_red[ax0, ax1] = T.max(rxplaceholder_red[ax0, ax1], rxplaceholder[ax0, k1, k2, ax1]) + # fmt: on + + mod = LegalizeOps()(Max) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_symbolic(): + # fmt: off + @tvm.script.ir_module + class Max: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): + a = T.int64() + d = T.int64() + gv: R.Tensor((a, d), "float32") = R.max(x, axis=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): + a = T.int64() + d = T.int64() + gv = R.call_tir(Expected.max, (x,), R.Tensor((a, d), dtype="float32")) + return gv + + @T.prim_func + def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, d, b, c): + with T.block("rxplaceholder_red"): + ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, k1, k2, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.min_value("float32") + rxplaceholder_red[ax0, ax1] = T.max(rxplaceholder_red[ax0, ax1], rxplaceholder[ax0, k1, k2, ax1]) + # fmt: on + + mod = LegalizeOps()(Max) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_min(): + # fmt: off + @tvm.script.ir_module + class Min: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 1, 1, 5), "float32"): + gv: R.Tensor((2, 1, 1, 5), "float32") = R.min(x, axis=[1, 2], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 
5), "float32")) -> R.Tensor((2, 1, 1, 5), "float32"): + gv = R.call_tir(Expected.min, (x,), R.Tensor((2, 1, 1, 5), dtype="float32")) + return gv + + @T.prim_func + def min(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(5), T.int64(3), T.int64(4)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, k1, k2, ax3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.max_value("float32") + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.min(rxplaceholder_red[ax0, ax1, ax2, ax3], rxplaceholder[ax0, k1, k2, ax3]) + # fmt: on + + mod = LegalizeOps()(Min) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_min_symbolic(): + # fmt: off + @tvm.script.ir_module + class Min: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): + a = T.int64() + d = T.int64() + gv: R.Tensor((a, 1, 1, d), "float32") = R.min(x, axis=[1, 2], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): + a = T.int64() + d = T.int64() + gv = R.call_tir(Expected.min, (x,), R.Tensor((a, 1, 1, d), dtype="float32")) + return gv + + @T.prim_func + def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, T.int64(1), T.int64(1), d], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), T.int64(1), d, b, c): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, k1, k2, ax3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.max_value("float32") + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.min(rxplaceholder_red[ax0, ax1, ax2, ax3], rxplaceholder[ax0, k1, k2, ax3]) + # fmt: on + + mod = LegalizeOps()(Min) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sum(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(Expected.sum, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, 
k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Sum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sum_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(Expected.sum, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Sum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_prod(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.prod(x, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv = R.call_tir(Expected.prod, (x,), R.Tensor((1, 1, 1, 1), dtype="float32")) + return gv + + @T.prim_func + def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(1) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Prod) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_prod_symbolic(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.prod(x, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv = R.call_tir(Expected.prod, (x,), R.Tensor((1, 1, 1, 1), dtype="float32")) + return gv + + @T.prim_func + def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3, i4, i5, i6, i7 in 
T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(1) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Prod) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_mean(): + # fmt: off + @tvm.script.ir_module + class Mean: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv: R.Tensor((3, 4), "float32") = R.mean(x, [0, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv = R.call_tir(Expected.mean, (x,), R.Tensor((3, 4), dtype="float32")) + return gv + + @T.prim_func + def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([T.int64(3), T.int64(4)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(3), T.int64(4), T.int64(2), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, ax0, ax1, k3]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] + for i0, i1 in T.grid(T.int64(3), T.int64(4)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_red[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] * T.float32(0.1) + # fmt: on + + mod = LegalizeOps()(Mean) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_mean_symbolic(): + # fmt: off + @tvm.script.ir_module + class Mean: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "c"), "float32"): + b = T.int64() + c = T.int64() + gv: R.Tensor((b, c), "float32") = R.mean(x, [0, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "c"), dtype="float32"): + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.mean, (x,), R.Tensor((b, c), dtype="float32")) + return gv + + @T.prim_func + def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [b, c], dtype="float32") + rxplaceholder_red = T.alloc_buffer([b, c], dtype="float32") + for i0, i1, i2, i3 in T.grid(b, c, a, d): + with T.block("rxplaceholder_red"): + ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, ax0, ax1, k3]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] + for i0, i1 in T.grid(b, c): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_red[ax0, 
ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] / T.Cast("float32", a * d) + # fmt: on + + mod = LegalizeOps()(Mean) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_std(): + # fmt: off + @tvm.script.ir_module + class Std: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.std(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(Expected.std, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), compute: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([], dtype="float32") + T_divide = T.alloc_buffer([], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_divide_1 = T.alloc_buffer([], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(rxplaceholder_red[()]) + T.writes(T_divide[()]) + T_divide[()] = rxplaceholder_red[()] * T.float32(0.0083333333333333332) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[()] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(T_multiply[k0, k1, k2, k3]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1, k2, k3] + with T.block("T_divide_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_divide_1[()]) + T_divide_1[()] = T_multiply_red[()] * T.float32(0.0083333333333333332) + with T.block("compute"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_divide_1[()]) + T.writes(compute[()]) + compute[()] = T.sqrt(T_divide_1[()]) + # fmt: on + + mod = LegalizeOps()(Std) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_std_symbolic(): + # fmt: off + @tvm.script.ir_module + class Std: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.std(x) + 
return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(Expected.std, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.alloc_buffer([], dtype="float32") + T_divide = T.alloc_buffer([], dtype="float32") + T_subtract = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_divide_1 = T.alloc_buffer([], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(rxplaceholder_red[()]) + T.writes(T_divide[()]) + T_divide[()] = rxplaceholder_red[()] / T.Cast("float32", a * b * c * d) + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[()] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(T_multiply[k0, k1, k2, k3]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1, k2, k3] + with T.block("T_divide_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_divide_1[()]) + T_divide_1[()] = T_multiply_red[()] / T.Cast("float32", a * b * c * d) + with T.block("compute"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_divide_1[()]) + T.writes(compute[()]) + compute[()] = T.sqrt(T_divide_1[()]) + # fmt: on + + mod = LegalizeOps()(Std) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_variance(): + # fmt: off + @tvm.script.ir_module + class Variance: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4, 1), "float32"): + gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((1, 3, 4, 1), dtype="float32"): + gv = R.call_tir(Expected.variance, (x,), R.Tensor((1, 3, 4, 1), dtype="float32")) + return gv + + @T.prim_func + def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1)), "float32")): + 
T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + T_divide_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply_red = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[k0, ax1, ax2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide_1[ax0, ax1, ax2, ax3]) + T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): + with T.block("T_multiply_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(T_multiply[k0, ax1, ax2, k3]) + T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) + with T.init(): + T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) + T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): + with T.block("T_divide_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + # fmt: on + + mod = LegalizeOps()(Variance) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_variance_symbolic(): + # fmt: off + @tvm.script.ir_module + class Variance: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", 1), "float32"): + b = T.int64() + c = T.int64() + gv: R.Tensor((1, b, c, 1), "float32") = R.variance(x, [0, 3], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), 
"float32")) -> R.Tensor((1, "b", "c", 1), "float32"): + b = T.int64() + c = T.int64() + gv = R.call_tir(Expected.variance, (x,), R.Tensor((1, b, c, 1), dtype="float32")) + return gv + + @T.prim_func + def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [T.int64(1), b, c, T.int64(1)], dtype="float32") + rxplaceholder_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + T_divide_1 = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[k0, ax1, ax2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide_1[ax0, ax1, ax2, ax3]) + T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.Cast("float32", a * d) + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): + with T.block("T_multiply_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(T_multiply[k0, ax1, ax2, k3]) + T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) + with T.init(): + T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) + T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): + with T.block("T_divide_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] / T.Cast("float32", a * d) + # fmt: on + + mod = LegalizeOps()(Variance) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py new file mode 100644 index 000000000000..27103e5f8f70 --- 
/dev/null +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -0,0 +1,1098 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import LegalizeOps +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_abs(): + # fmt: off + @tvm.script.ir_module + class Abs: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.abs(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_abs(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Abs) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_abs_symbolic(): + # fmt: off + @tvm.script.ir_module + class Abs: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.abs(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Abs) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ceil(): + # fmt: off + @tvm.script.ir_module + class Ceil: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ceil(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), 
T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Ceil) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ceil_int(): + # fmt: off + @tvm.script.ir_module + class Ceil: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.ceil(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] + + @R.function + def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): + gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Ceil) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ceil_symbolic(): + # fmt: off + @tvm.script.ir_module + class Ceil: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.ceil(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) + compute = T.match_buffer(var_compute, (m, n)) + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Ceil) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cos(): + # fmt: off + @tvm.script.ir_module + class Cos: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.cos(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_cos(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, 
i1_1]) + compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Cos) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cos_symbolic(): + # fmt: off + @tvm.script.ir_module + class Cos: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.cos(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Cos) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_exp(): + # fmt: off + @tvm.script.ir_module + class Exp: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.exp(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Exp) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_exp_symbolic(): + # fmt: off + @tvm.script.ir_module + class Exp: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.exp(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Exp) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor(): + # fmt: off + @tvm.script.ir_module + class Floor: + @R.function + def main(x: 
R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.floor(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Floor) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_int(): + # fmt: off + @tvm.script.ir_module + class Floor: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.floor(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] + + @R.function + def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): + gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Floor) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_symbolic(): + # fmt: off + @tvm.script.ir_module + class Floor: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.floor(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) + compute = T.match_buffer(var_compute, (m, n)) + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Floor) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log(): + # fmt: off + @tvm.script.ir_module + class Log: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.log(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_log(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), 
T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Log) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log_symbolic(): + # fmt: off + @tvm.script.ir_module + class Log: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.log(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Log) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_negative(): + # fmt: off + @tvm.script.ir_module + class Negative: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.negative(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_negative(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(Negative) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_negative_symbolic(): + # fmt: off + @tvm.script.ir_module + class Negative: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.negative(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + 
compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(Negative) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_round(): + # fmt: off + @tvm.script.ir_module + class Round: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.round(x) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Round) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_round_int(): + # fmt: off + @tvm.script.ir_module + class Round: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.round(x) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] + + @R.function + def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): + gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Round) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_round_symbolic(): + # fmt: off + @tvm.script.ir_module + class Round: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.round(x) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) + compute = T.match_buffer(var_compute, (m, n)) + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1]) + + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Round) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sigmoid(): + # fmt: off + @tvm.script.ir_module + class Sigmoid: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sigmoid(x) + return gv + + @tvm.script.ir_module + class Expected: + 
@R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sigmoid(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sigmoid) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sigmoid_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sigmoid: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.sigmoid(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sigmoid) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sign(): + # fmt: off + @tvm.script.ir_module + class Sign: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sign(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_sign"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_sign[v_ax0, v_ax1]) + T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0))) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Sign) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sign_int(): + # fmt: off + @tvm.script.ir_module + class Sign: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.sign(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_sign"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + 
T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_sign[v_ax0, v_ax1]) + T_sign[v_ax0, v_ax1] = T.Select(0 < rxplaceholder[v_ax0, v_ax1], 1, T.Select(rxplaceholder[v_ax0, v_ax1] < 0, -1, 0)) + + @R.function + def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): + gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) + return gv + # fmt: on + + mod = LegalizeOps()(Sign) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sign_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sign: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.sign(x) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) + T_sign = T.match_buffer(var_T_sign, (m, n)) + for ax0, ax1 in T.grid(m, n): + with T.block("T_sign"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_sign[v_ax0, v_ax1]) + T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0))) + + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Sign) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sin(): + # fmt: off + @tvm.script.ir_module + class Sin: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sin(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sin) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sin_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sin: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.sin(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with 
T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sin) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sqrt(): + # fmt: off + @tvm.script.ir_module + class Sqrt: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sqrt(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sqrt) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sqrt_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sqrt: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.sqrt(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sqrt) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tanh(): + # fmt: off + @tvm.script.ir_module + class Tanh: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.tanh(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_tanh(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Tanh) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tanh_symbolic(): + # fmt: off + @tvm.script.ir_module + class Tanh: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() 
+ n = T.int64() + gv: R.Tensor((m, n), "float32") = R.tanh(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Tanh) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_clip_symbolic(): + @tvm.script.ir_module + class Clip: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.tir_clip, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int64() + n = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + compute[v_i0, v_i1] = T.max( + T.min(rxplaceholder[v_i0, v_i1], T.float32(8)), T.float32(5) + ) + + mod = LegalizeOps()(Clip) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py new file mode 100644 index 000000000000..b6189488404f --- /dev/null +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -0,0 +1,444 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
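+# The tests below exercise relax.transform.LiftTransformParams (each case calls
+# relax.transform.LiftTransformParams()(mod)): for functions whose "num_input" attribute
+# marks the leading runtime inputs, bindings that depend only on the remaining parameters
+# are lifted into a companion "<func>_transform_params" function, and the original
+# function is rewritten to accept the pre-transformed parameter tuple instead.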
+ +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T +from tvm.script import ir as I +import numpy as np +import tvm.topi.testing + + +def test_basic(): + @tvm.script.ir_module + class Before: + @T.prim_func + def transform_layout_IOHW_to_OIHW( + w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") + ) -> None: + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def main( + x: R.Tensor((1, 3, 224, 224), "float32"), + w1: R.Tensor((3, 16, 3, 3), "float32"), + w2: R.Tensor((16, 16, 3, 3), "float32"), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + w1_transformed = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, w1, R.Tensor((16, 3, 3, 3), "float32") + ) + conv1 = R.nn.conv2d( + x, w1_transformed, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + R.output(conv2) + return conv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 3, 224, 224), dtype="float32"), + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") + ), + ) -> R.Tensor((1, 16, 224, 224), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 3, 3, 3), dtype="float32") = params[1] + conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + x, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + conv1, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + R.output(conv2) + return conv2 + + @T.prim_func + def transform_layout_IOHW_to_OIHW( + w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") + ): + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w1[i, o, h, w]) + T.writes(out[o, i, h, w]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def main_transform_params( + params: R.Tuple( + R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") + ): + cls = Expected + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0] + lv2 = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, + (lv1,), + out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + ) + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 3, 3, 3), dtype="float32"), + ) = (lv, lv2) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_tuple(): + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), w1: R.Tensor((16, 16, 3, 3), "float32") + ) -> R.Tensor((1, 16, 224, 224), "float32"): + 
R.func_attr({"num_input": 1}) + with R.dataflow(): + l0 = (w1,) + l1 = (l0,) + l2 = l1[0] + l3 = l2[0] + conv1 = R.nn.conv2d(x, l3, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + conv2 = R.nn.conv2d( + conv1, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + R.output(conv2) + return conv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), dtype="float32"), + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ), + ) -> R.Tensor((1, 16, 224, 224), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + x, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + conv1, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + R.output(conv2) + return conv2 + + @R.function + def main_transform_params( + params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + l0: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = (lv1,) + l1: R.Tuple(R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))) = (l0,) + l2: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = l1[0] + lv2: R.Tensor((16, 16, 3, 3), dtype="float32") = l2[0] + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + ) = (lv, lv2) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_condition(): + """Test case that the conditional statement can't be lifted""" + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), + w1: R.Tensor((16, 16, 3, 3), "float32"), + w2: R.Tensor((16, 16, 3, 3), "float32"), + cond: R.Tensor((), "bool"), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + R.func_attr({"num_input": 1}) + if cond: + w = w1 + else: + w = w2 + with R.dataflow(): + conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + R.output(conv1) + return conv1 + + @tvm.script.ir_module + class Expected: + @R.function + def main_transform_params( + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv2: R.Tensor((), dtype="bool") = params[2] + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ) = (lv, lv1, lv2) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), + 
params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + gv: R.Tensor((), dtype="bool") = params[2] + if gv: + gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 + else: + gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2 + with R.dataflow(): + conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + R.output(conv1) + return conv1 + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_multiple_functions(): + @tvm.script.ir_module + class Before: + @R.function + def func1( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((256, 256), "float32"), + ) -> R.Tensor((256, 256), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + y = R.matmul(x, w1_t) + R.output(y) + return y + + @R.function + def func2( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((128, 256), "float32"), + ) -> R.Tensor((256, 128), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + y = R.matmul(x, w1_t) + R.output(y) + return y + + @R.function + def func3( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((256, 256), "float32"), + ) -> R.Tensor((256, 256), "float32"): + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + y = R.matmul(x, w1_t) + R.output(y) + return y + + @tvm.script.ir_module + class Expected: + @R.function + def func1( + x: R.Tensor((256, 256), dtype="float32"), + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, out_dtype="void") + R.output(y) + return y + + @R.function + def func1_transform_params( + params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0]) + gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @R.function + def func2( + x: R.Tensor((256, 256), dtype="float32"), + params: R.Tuple(R.Tensor((256, 128), dtype="float32")), + ) -> R.Tensor((256, 128), dtype="float32"): + with R.dataflow(): + lv1: R.Tensor((256, 128), dtype="float32") = params[0] + y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, out_dtype="void") + R.output(y) + return y + + @R.function + def func2_transform_params( + params: R.Tuple(R.Tensor((128, 256), dtype="float32")) + ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 256), dtype="float32") = params[0] + lv1: R.Tensor((256, 128), dtype="float32") = R.permute_dims(lv, axes=[1, 0]) + gv: R.Tuple(R.Tensor((256, 128), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @R.function + def func3( + x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + w1_t: R.Tensor((256, 256), dtype="float32") = R.permute_dims(w1, axes=[1, 0]) + y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_t, out_dtype="void") + 
R.output(y) + return y + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_stop_lifting(): + @tvm.script.ir_module + class Before: + @R.function + def func1( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((256, 256), "float32"), + ) -> R.Tensor((256, 256), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + w1_t1 = R.builtin.stop_lift_params(w1_t) + w1_add = R.add(w1_t1, R.const(1, "float32")) + y = R.matmul(x, w1_add) + R.output(y) + return y + + @I.ir_module + class Expected: + @R.function + def func1( + x: R.Tensor((256, 256), dtype="float32"), + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, R.const(1, "float32")) + y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, out_dtype="void") + R.output(y) + return y + + @R.function + def func1_transform_params( + params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0]) + gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py new file mode 100644 index 000000000000..d5a3b1aa599a --- /dev/null +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -0,0 +1,1070 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
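+# The modules below are inputs and expected outputs for merging composite functions
+# (relax.transform.MergeCompositeFunctions): composite functions marked with the
+# "Composite" attribute that are offloaded to the same external compiler are grouped
+# under a single outer function carrying the "Codegen" and "global_symbol" attributes,
+# while cases such as Diamond_cyclic_dep and MultipleProducersCyclic check that merging
+# does not introduce cyclic dependencies between groups of different compilers.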
+import pytest + +import tvm +from tvm import relax +from tvm.script import relax as R + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Conv2dReLUx2 + with R.dataflow(): + lv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight11, + padding=[1, 1, 1, 1], + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight21, + padding=[0, 0, 0, 0], + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class Conv2dReLUx2_merged: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Conv2dReLUx2_merged + with R.dataflow(): + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1( + data, weight1, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr( + { + "Primitive": 1, + "Codegen": "dnnl", + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1", + } + ) + with R.dataflow(): + + @R.function + def lv( + data11: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight111: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data11, + weight111, + padding=[1, 1, 1, 1], + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, weight11) + + @R.function + def lv11( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight211: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1}) + with R.dataflow(): + lv21: R.Tensor((1, 
64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight211, + padding=[0, 0, 0, 0], + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv21) + R.output(gv2) + return gv2 + + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, weight21) + R.output(gv3) + return gv3 + + +@tvm.script.ir_module +class Diamond: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Diamond + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight + ) + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_relu(lv2) + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_gelu(lv2) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_add(lv3, lv4) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + lv: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu( + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_add( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class Diamond_merged: + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + R.output(gv4) + return gv4 + + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) + + @R.function + def 
lv1( + lv11: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) + R.output(gv1) + return gv1 + + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2) + + @R.function + def lv21( + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4) + R.output(gv) + return gv + + lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2) + + @R.function + def lv31( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41) + R.output(gv2) + return gv2 + + @R.function + def main( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Diamond_merged + with R.dataflow(): + gv5: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2) + R.output(gv5) + return gv5 + + +@tvm.script.ir_module +class Diamond_cyclic_dep: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Diamond_cyclic_dep + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d( + data, weight + ) + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_relu(lv2) + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_gelu(lv2) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_add(lv3, lv4) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + lv: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu( + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_add( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), 
dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class Diamond_cyclic_dep_merged: + @R.function + def main( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + cls = Diamond_cyclic_dep_merged + with R.dataflow(): + lv4: R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), + R.Tensor((1, 64, 54, 54), dtype="float32"), + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data2, weight2) + lv12: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[0] + lv22: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[1] + lv31: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_gelu1(lv12) + gv5: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_add1(lv22, lv31) + R.output(gv5) + return gv5 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32") + ): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) + + @R.function + def lv1( + lv11: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) + R.output(gv1) + return gv1 + + gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv) + R.output(gv, gv11) + return (gv, gv11) + + @R.function + def fused_relax_nn_gelu1( + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr( + {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": "fused_relax_nn_gelu1"} + ) + with R.dataflow(): + + @R.function + def lv21( + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_B.gelu", "Primitive": 1}) + with R.dataflow(): + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv3) + R.output(gv2) + return gv2 + + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add1( + lv32: R.Tensor((1, 64, 54, 54), dtype="float32"), + lv41: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Codegen": "compiler_A", "global_symbol": "fused_relax_add1"}) + with R.dataflow(): + + @R.function + def lv33( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": 
"compiler_A.add", "Primitive": 1}) + with R.dataflow(): + gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv31) + return gv31 + + gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41) + R.output(gv6) + return gv6 + + +@tvm.script.ir_module +class MultipleProducers: + @R.function + def main( + x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + cls = MultipleProducers + with R.dataflow(): + lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(x1) + lv2: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(x2) + lv3: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(lv1) + lv4: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_add(lv3, lv4) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add( + lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1) + R.output(gv) + return gv + + +@tvm.script.ir_module +class MultipleProducers_merged: + @R.function + def fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add( + x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + lv1: R.Tensor((10,), dtype="float32") = lv(x1) + + @R.function + def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + lv2: R.Tensor((10,), dtype="float32") = lv11(x2) + lv3: R.Tensor((10,), dtype="float32") = lv(lv1) + lv4: R.Tensor((10,), dtype="float32") = lv11(lv2) + + @R.function + def lv21( + lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1) + R.output(gv) + return gv + + gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4) + R.output(gv1) + 
return gv1 + + @R.function + def main( + x12: R.Tensor((10,), dtype="float32"), x22: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + cls = MultipleProducers_merged + with R.dataflow(): + gv4: R.Tensor( + (10,), dtype="float32" + ) = cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add( + x12, x22 + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class MultipleProducersCyclic: + @R.function + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = MultipleProducersCyclic + with R.dataflow(): + lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv1) + lv3: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_add(lv1, lv3) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add( + lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1) + R.output(gv) + return gv + + +@tvm.script.ir_module +class MultipleProducersCyclic_merged: + @R.function + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = MultipleProducersCyclic_merged + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu1(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv) + gv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu_relax_add(lv2, lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu1( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + {"Codegen": "compiler_A", "Primitive": 1, "global_symbol": "fused_relax_nn_relu1"} + ) + # block 0 + with R.dataflow(): + + @R.function + def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111) + R.output(gv2) + return gv2 + + gv1: R.Tensor((10,), dtype="float32") = lv1(x11) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_gelu_relax_add( + lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((10,), 
dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + lv3: R.Tensor((10,), dtype="float32") = lv12(lv21) + + @R.function + def lv22( + lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1) + R.output(gv4) + return gv4 + + gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3) + R.output(gv5) + return gv5 + + +@tvm.script.ir_module +class MergeCompilerRegionsExample: + @R.function + def main( + x1: R.Tensor((10,), dtype="float32"), + x2: R.Tensor((10,), dtype="float32"), + x3: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10,), dtype="float32"): + cls = MergeCompilerRegionsExample + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = cls.fused_relax_add(x1, x2) + lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(x3) + lv11: R.Tensor((10,), dtype="float32") = cls.fused_relax_add(lv, lv1) + lv12: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv11) + lv2: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(lv11) + lv21: R.Tensor((10,), dtype="float32") = cls.fused_relax_add(lv12, lv2) + gv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(lv21) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + add2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv) + return gv + + @R.function + def fused_relax_add( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x31: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31) + R.output(gv3) + return gv3 + + +@tvm.script.ir_module +class MergeCompilerRegionsExampleRef: + @R.function + def fused_relax_add_relax_add_relax_nn_relu( + x1: R.Tensor((10,), dtype="float32"), + x2: R.Tensor((10,), dtype="float32"), + lv: R.Tensor((10,), dtype="float32"), + ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32")): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_add_relax_add_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv1( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv) + return gv + + lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2) + gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv) + + @R.function + def lv11(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv2) + return gv2 + + gv11: R.Tensor((10,), dtype="float32") = 
lv11(gv1) + R.output(gv1, gv11) + return (gv1, gv11) + + @R.function + def fused_relax_add_relax_nn_relu( + lv12: R.Tensor((10,), dtype="float32"), lv3: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_add_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv21( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv) + return gv + + lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3) + + @R.function + def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv2) + return gv2 + + gv3: R.Tensor((10,), dtype="float32") = lv31(lv22) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_gelu1( + x3: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr( + {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": "fused_relax_nn_gelu1"} + ) + with R.dataflow(): + + @R.function + def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31) + R.output(gv4) + return gv4 + + gv5: R.Tensor((10,), dtype="float32") = lv4(x3) + R.output(gv5) + return gv5 + + @R.function + def main( + x12: R.Tensor((10,), dtype="float32"), + x22: R.Tensor((10,), dtype="float32"), + x32: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10,), dtype="float32"): + cls = MergeCompilerRegionsExampleRef + with R.dataflow(): + lv5: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1(x32) + lv13: R.Tuple( + R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32") + ) = cls.fused_relax_add_relax_add_relax_nn_relu(x12, x22, lv5) + lv23: R.Tensor((10,), dtype="float32") = lv13[0] + lv32: R.Tensor((10,), dtype="float32") = lv13[1] + lv41: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1(lv23) + gv6: R.Tensor((10,), dtype="float32") = cls.fused_relax_add_relax_nn_relu(lv41, lv32) + R.output(gv6) + return gv6 + + +@tvm.script.ir_module +class ModuleWithNonComposite: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + cls = ModuleWithNonComposite + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d(data, weight) + conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(conv) + return conv + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[1, 1, 1, 1], + ) + R.output(gv) + return gv + + +@tvm.script.ir_module +class ModuleWithNonComposite_ref: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 
64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + cls = ModuleWithNonComposite_ref + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d1( + data, weight + ) + conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(conv) + return conv + + @R.function + def fused_relax_nn_conv2d1( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr( + {"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_nn_conv2d1"} + ) + with R.dataflow(): + + @R.function + def lv1( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data2, + weight2, + padding=[1, 1, 1, 1], + ) + R.output(gv) + return gv + + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1, weight1) + R.output(gv1) + return gv1 + + +def check(mod, expected): + partitioned = relax.transform.MergeCompositeFunctions()(mod) + tvm.ir.assert_structural_equal(partitioned, expected) + + +def test_conv2d_relu_x2(): + check(Conv2dReLUx2, Conv2dReLUx2_merged) + + +def test_diamond_cyclic_dep(): + """ + O = Offloaded to A + X = Offloaded to B + + O O + / \\ / \\ + O X --> O + + X + \\ / \\ / + O O + + We cannot merge all 'O' since it would create a cyclic dependency between the group of `X`. + """ + check(Diamond_cyclic_dep, Diamond_cyclic_dep_merged) + + +def test_diamond(): + """ + O = Offloaded to A + + O O + / \\ / \\ + O O --> O O + \\ / \\ / + O O + + """ + check(Diamond, Diamond_merged) + + +def test_merge_producers(): + """ + Test merging multiple producer groups into a single representative group. + O O + | | + O O + \\ / + O + """ + check(MultipleProducers, MultipleProducers_merged) + + +def test_merge_producers_cyclic_dep(): + """ + Test when multiple producer groups being blocked to merge due to circular dependency + in the result. + O + |\\ + | X + | | + | O + |/ + O + """ + check(MultipleProducersCyclic, MultipleProducersCyclic_merged) + + +def test_merge_compiler_regions_example(): + """ + A tricky example from https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830 + See also the corresponding test case for Relay MergeCompilerRegions in relay/test_pass_merge_compiler_regions.py. + """ + check( + MergeCompilerRegionsExample, + MergeCompilerRegionsExampleRef, + ) + + +def test_mixed_non_composite(): + check(ModuleWithNonComposite, ModuleWithNonComposite_ref) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py new file mode 100644 index 000000000000..13c81ba962f8 --- /dev/null +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tempfile
+
+import tvm
+import tvm.testing
+import tvm.meta_schedule as ms
+from tvm import relax
+from tvm.ir import transform
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.relax.transform.tuning_api import Trace
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+target = tvm.target.Target("llvm --num-cores=16")
+
+
+@tvm.script.ir_module
+class InputModule:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        T.func_attr({"global_symbol": "tir_matmul"})
+        k = T.int32()
+        A = T.match_buffer(x, (32, 32))
+        B = T.match_buffer(y, (32, 32))
+        C = T.match_buffer(z, (32, 32))
+
+        for (i0, j0, k0) in T.grid(32, 32, 32):
+            with T.block():
+                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
+                with T.init():
+                    C[i, j] = 0.0
+                C[i, j] += A[i, k] * B[j, k]
+
+    @T.prim_func
+    def tir_relu(x: T.handle, y: T.handle):
+        T.func_attr({"global_symbol": "tir_relu"})
+        A = T.match_buffer(x, (32, 32))
+        B = T.match_buffer(y, (32, 32))
+        for (i, j) in T.grid(32, 32):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = T.max(A[vi, vj], 0.0)
+
+    @R.function
+    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        cls = InputModule
+        with R.dataflow():
+            lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
+            lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
+            R.output(lv1)
+        return lv1
+
+
+# TODO(@sunggg): determine how to pass MS database object across different passes.
+# PassContext might be an option, but we already have TuningAPI database.
+# (MS database and TuningAPI database will be unified in the future)
+# For now, we only support the default JSON database config.
+def test_ms_tuning_irmodule():
+
+    mod = InputModule
+    assert isinstance(mod, IRModule)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, transform.PassContext(trace=Trace(mod), opt_level=0):
+            tuning_pass = relax.transform.MetaScheduleTuneIRMod(
+                params={}, work_dir=work_dir, max_trials_global=4
+            )
+            out_mod = tuning_pass(mod)
+            assert PassContext.current().get_trace_stack_size() == 1
+            assert PassContext.current().get_current_trace().size == 1
+            tvm.ir.assert_structural_equal(mod, out_mod)
+
+            application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
+
+            out_mod = application_pass(mod)
+            assert not tvm.ir.structural_equal(mod, out_mod)
+
+
+def test_ms_tuning_primfunc():
+    mod = InputModule
+    assert isinstance(mod, IRModule)
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, transform.PassContext(trace=Trace(mod), opt_level=0):
+            tuning_pass = relax.transform.MetaScheduleTuneTIR(
+                work_dir=work_dir, max_trials_global=4
+            )
+            out_mod = tuning_pass(mod)
+            assert PassContext.current().get_trace_stack_size() == 1
+            # TODO (@sunggg): Need to determine how to track subgraph-level tuning traces.
+            # Currently, we don't track this, so the trace size is not asserted here. Revisit this later.
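+            # (For comparison, test_ms_tuning_irmodule above also asserts
+            # `get_current_trace().size == 1`; no equivalent per-PrimFunc check exists yet.)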
+ tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + +@tvm.script.ir_module +class DefaultScheduledModule: + @T.prim_func + def tir_matmul( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + ): + T.func_attr({"global_symbol": "tir_matmul"}) + # with T.block("root"): + for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + for k0 in range(32): + with T.block(""): + i = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) // 32) + j = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) % 32) + k = T.axis.reduce(32, k0) + T.reads(A[i, k], B[j, k]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")): + T.func_attr({"global_symbol": "tir_relu"}) + # with T.block("root"): + for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block(""): + vi = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32) + vj = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = T.max(A[vi, vj], T.float32(0)) + + @R.function + def main( + x: R.Tensor((32, 32), dtype="float32"), w: R.Tensor((32, 32), dtype="float32") + ) -> R.Tensor((32, 32), dtype="float32"): + with R.dataflow(): + lv0 = R.call_tir( + DefaultScheduledModule.tir_matmul, + (x, w), + out_sinfo=R.Tensor((32, 32), dtype="float32"), + ) + lv1 = R.call_tir( + DefaultScheduledModule.tir_relu, + (lv0,), + out_sinfo=R.Tensor((32, 32), dtype="float32"), + ) + R.output(lv1) + return lv1 + + +def test_ms_database_apply_fallback(): + mod = InputModule + target_cuda = tvm.target.Target("nvidia/geforce-rtx-3090-ti") + assert isinstance(mod, IRModule) + with tempfile.TemporaryDirectory() as work_dir: + with target_cuda, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneTIR( + work_dir=work_dir, max_trials_global=0 + ) + out_mod = tuning_pass(mod) + tvm.ir.assert_structural_equal(mod, out_mod) + default_pass = tvm.tir.transform.DefaultGPUSchedule() + out_mod = default_pass(mod) + tvm.ir.assert_structural_equal(out_mod, DefaultScheduledModule) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py new file mode 100644 index 000000000000..874e83c7f955 --- /dev/null +++ b/tests/python/relax/test_transform_normalize.py @@ -0,0 +1,556 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +import tvm.testing +from tvm import relax +from tvm import tir +from tvm.ir.base import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_normalize_function(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([m, n], "float16")) + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function here. + mul_add = relax.Function( + [x], + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + ret_struct_info=R.Tensor("float16", ndim=2), + ) + + # Note: from_expr api names private function (function without global_symbol) as "main" + before_mod = tvm.IRModule.from_expr(mul_add) + + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2): + gv = R.add(x, x) + gv1 = R.add(x, x) + return R.multiply(gv, gv1) + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if(): + cond = relax.Var("cond", R.Tensor([], "bool")) + x = relax.Var("x", R.Tensor([1], "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function and If here. + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + cond, + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)), + ), + ) + ] + ) + ], + y, + ), + ret_struct_info=R.Tensor("float32", ndim=1), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor(dtype="float32", ndim=1): + if cond: + gv = R.add(x, x) + gv1 = R.add(x, x) + y = R.multiply(gv, gv1) + else: + gv = R.multiply(x, x) + gv1 = R.multiply(x, x) + y = R.add(gv, gv1) + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_no_op(): + # the normalize pass should be no-op for IR in ANF + @tvm.script.ir_module + class ANFMod1: + @R.function + def f(x: R.Tensor(dtype="float32")): + gv = R.add(x, x) + gv1 = R.add(gv, gv) + gv2 = R.add(gv, gv1) + return (gv, gv2) + + before_mod = ANFMod1 + after_mod = relax.transform.Normalize()(before_mod) + assert_structural_equal(before_mod, after_mod, map_free_vars=True) + + @tvm.script.ir_module + class ANFMod2: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + with R.dataflow(): + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + ) + R.output(gv0) + return gv0 + + mod = ANFMod2 + mod_post = relax.transform.Normalize()(mod) + + assert_structural_equal(mod, mod_post) + + +def test_normalize_seq_body(): + # a seq expression with a non-leaf body should bind the body to a var as well + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + seq = relax.SeqExpr([], relax.op.add(x, y)) + f = relax.Function( + [x, y], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = 
relax.transform.Normalize()(before_mod) + + @R.function + def expected( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # normalization inserts a binding like this + z = R.add(x, y) + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_func_body(): + # a function with a body that is not a seq expr should have it wrapped in a seq expr + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + f = relax.Function( + [x, y], + relax.op.add(x, y), + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # result will be a seq expr where the body is a var + z = R.add(x, y) + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if_branches(): + # an if node's branches must be seq exprs + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0), + # but normalization fails to infer these even though it should + z = relax.Var("z") + cond = relax.Var("cond", R.Tensor([], "bool")) + plus = relax.op.add(x, y) + mult = relax.op.multiply(x, y) + if_node = relax.If(cond, plus, mult) + seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z) + f = relax.Function( + [cond, x, y], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), dtype="bool"), + x: R.Tensor((), dtype="int32"), + y: R.Tensor((), dtype="int32"), + ) -> R.Tensor(ndim=0, dtype="int32"): + # the bodies of the branches will be seq exprs with a binding + if cond: + w = R.add(x, y) + z = w + else: + w = R.multiply(x, y) + z = w + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if_condition(): + cond = relax.Var("cond", R.Tensor([], "bool")) + x = relax.Var("x", R.Tensor([1], "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # The condition is wrapped in a tuple and then indexed + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + relax.TupleGetItem(relax.Tuple([cond]), 0), + relax.op.add(x, x), + relax.op.multiply(x, x), + ), + ) + ] + ) + ], + y, + ), + ret_struct_info=R.Tensor("float32", ndim=1), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor(dtype="float32", ndim=1): + c = R.TupleGetItem(R.tuple(cond), 0) + if c: + gv = R.add(x, x) + y = gv + else: + gv = R.multiply(x, x) + y = gv + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_tuple_get_item(): + x = relax.Var("x", R.Tensor([], "int32")) + f = relax.Function( + [x], + relax.TupleGetItem( + relax.TupleGetItem( + relax.Tuple([relax.Tuple([x])]), + 0, + ), + 0, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + # TODO: Revisit once we canonicalize SeqExprs (part of normalization?) 
+ # Not using the parser this time because writing it out correctly results in + # *one* binding block, whereas the normalized version has *two* + idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")])) + ret_var = relax.Var("ret", R.Tensor([], "int32")) + expected_f = relax.Function( + [x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0) + ) + ] + ), + relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]), + ], + ret_var, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + expected_mod = tvm.IRModule.from_expr(expected_f) + # apply normalization to fill in type and shape annotations (tedious otherwise) + final_mod = relax.transform.Normalize()(expected_mod) + + assert_structural_equal(after_mod, final_mod) + + +def test_normalize_combine_nearby_blocks(): + x = relax.Var("x", R.Tensor([], "int32")) + v0 = relax.Var("v0", R.Tensor([], "int32")) + v1 = relax.Var("v1", R.Tensor([], "int32")) + v2 = relax.Var("v2", R.Tensor([], "int32")) + v3 = relax.Var("v3", R.Tensor([], "int32")) + f = relax.Function( + [x], + relax.SeqExpr( + [ + relax.DataflowBlock([relax.VarBinding(v0, x)]), + relax.DataflowBlock([relax.VarBinding(v1, v0)]), + relax.BindingBlock([relax.VarBinding(v2, v1)]), + relax.BindingBlock([relax.VarBinding(v3, v2)]), + ], + v3, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(x: R.Tensor((), "int32")): + with R.dataflow(): + v0 = x + v1 = v0 + R.output(v0, v1) + v2 = v1 + v3 = v2 + return v3 + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_nested_seq(): + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + z = relax.const(2) + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_nested_seq_dataflow(): + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + q = relax.Var("u", R.Tensor([], "int32")) + w = relax.DataflowVar("w", R.Tensor([], "int32")) + u = relax.Var("u", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [ + relax.BindingBlock([relax.VarBinding(q, relax.const(2))]), + relax.DataflowBlock( + [ + relax.VarBinding(w, q), + relax.VarBinding(u, w), + ] + ), + relax.BindingBlock([relax.VarBinding(z, u)]), + ], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + q = relax.const(2) + with R.dataflow(): + w = q + u = w + R.output(u) + z = u + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_deeply_nested_seq(): + x = relax.Var("x", R.Tensor([], 
"int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + u = relax.Var("u", R.Tensor([], "int32")) + v = relax.Var("v", R.Tensor([], "int32")) + w = relax.Var("w", R.Tensor([], "int32")) + _ = relax.Var("w", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + z, + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(u, relax.const(2)), + relax.MatchCast( + _, u, R.Tensor([], "int32") + ), + relax.VarBinding(v, u), + relax.MatchCast( + w, v, R.Tensor([], "int32") + ), + ] + ) + ], + w, + ), + ) + ] + ) + ], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + u = relax.const(2) + _ = R.match_cast(u, R.Tensor((), "int32")) + v = u + w = R.match_cast(v, R.Tensor((), "int32")) + z = w + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +@pytest.mark.xfail() +def test_nesting_non_dataflow_in_dataflow_error(): + x = relax.DataflowVar("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.DataflowBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])], + z, + ), + ), + ] + ) + ], + y, + ) + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + # should fail due to a normal binding block being inside a dataflowblock + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py new file mode 100644 index 000000000000..ecf2a96064e6 --- /dev/null +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T + + +def test_reshape_expand_dims(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), + "float32", + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + cls = Module + with R.dataflow(): + y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + z = R.call_tir( + cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32") + ) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32" + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + with R.dataflow(): + cls = Expected + y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3)) + # Note: `z` is the output var of the dataflow block, and is thus + # not expected to be rewritten. 
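+                # (RewriteDataflowReshape only rewrites call_tir bindings whose binding
+                # var is a DataflowVar; see test_reshape_non_dataflow below.)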
+ z = R.call_tir( + cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32") + ) + R.output(z) + return z + + assert relax.analysis.has_reshape_pattern(Module["expand_dims"]) + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_pattern_detect(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32")): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720)) + v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320)) + v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64)) + v_ax3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) + T.reads(rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)), + "float32", + ), + ): + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1, i5_1] + + @R.function + def main( + x: R.Tensor((2, 4096, 320), dtype="float32") + ) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"): + cls = Module + with R.dataflow(): + y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4096, 5, 64), dtype="float32")) + z = R.call_tir( + cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), "float32") + ) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32"), expand_dims_1: T.Buffer((T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)), "float32")): + # with T.block("root"): + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), 
T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1]) + T.writes(expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) + expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1, i5_1] + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32")): + # with T.block("root"): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720)) + v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320)) + v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64)) + v_ax3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) + T.reads(rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)] + + @R.function + def main(x: R.Tensor((2, 4096, 320), dtype="float32")) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"): + cls = Expected + with R.dataflow(): + y: R.Tensor((2, 4096, 5, 64), dtype="float32") = R.reshape(x, R.shape([2, 4096, 5, 64])) + z = R.call_tir(cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32")) + R.output(z) + return z + # fmt: on + + assert relax.analysis.has_reshape_pattern(Module["reshape"]) + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_non_dataflow(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @R.function + def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + cls = Module 
+            y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
+            return y
+
+
+    assert relax.analysis.has_reshape_pattern(Module["reshape"])
+    # The binding var of the call_tir is not a DataflowVar. So the pass does no change.
+    mod = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
new file mode 100644
index 000000000000..521fcc1924e7
--- /dev/null
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -0,0 +1,840 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R, tir as T
+
+
+def test_basic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Module
+            alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, 4), dtype="float32") = alloc
+            lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,))
+            alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((8,), dtype="float32") = alloc1
+            alloc2: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((8,), dtype="float32") = alloc2
+            alloc3: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32",
runtime_device_index=0) + _3: R.Tuple() = cls.pad(lv3, alloc3) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) + _4: R.Tuple() = cls.log(lv4, alloc4) + gv: R.Tensor((10,), dtype="float32") = alloc4 + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @T.prim_func + def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): + T.evaluate(0) + + @T.prim_func + def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage(R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32") + alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), dtype="float32") + _: R.Tuple() = cls.exp(x, alloc) + lv: R.Tensor((2, 4), dtype="float32") = alloc + lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,)) + storage1: R.Object = R.memory.alloc_storage(R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32") + alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([8]), dtype="float32") + _1: R.Tuple() = cls.relu(lv1, alloc1) + _2: R.Tuple() = R.memory.kill_tensor(alloc) + _3: R.Tuple() = R.memory.kill_tensor(lv1) + lv2: R.Tensor((8,), dtype="float32") = alloc1 + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([8]), dtype="float32") + _4: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2) + _5: R.Tuple() = R.memory.kill_tensor(alloc1) + lv3: R.Tensor((8,), dtype="float32") = alloc2 + alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([10]), dtype="float32") + _6: R.Tuple() = cls.pad(lv3, alloc3) + _7: R.Tuple() = R.memory.kill_tensor(alloc2) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) + _8: R.Tuple() = cls.log(lv4, alloc4) + _9: R.Tuple() = R.memory.kill_tensor(alloc3) + gv5: R.Tensor((10,), dtype="float32") = alloc4 + _11: R.Tuple() = R.memory.kill_storage(storage) + _10: R.Tuple() = R.memory.kill_storage(storage1) + return gv5 + # fmt: on + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_different_dtype(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: 
T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="int32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.add1(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="int32") = alloc1 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = cls.add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32" + ) + alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="int32" + ) + _2: R.Tuple() = cls.add1(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="int32") = alloc1 + _5: R.Tuple() = R.memory.kill_storage(storage) + _4: R.Tuple() = R.memory.kill_storage(storage1) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dtype_bool(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "bool"), + B: T.Buffer((T.int64(2), T.int64(3)), "bool"), + C: T.Buffer((T.int64(2), T.int64(3)), "bool"), + ): + T.evaluate(0) + + @R.function + def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"): + cls = Module + alloc: R.Tensor((2, 3), dtype="bool") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="bool", runtime_device_index=0 + ) + _1: R.Tuple() = cls.add1(y, y, alloc) + gv1: R.Tensor((2, 3), dtype="bool") = alloc + return y + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "bool"), + B: T.Buffer((T.int64(2), T.int64(3)), "bool"), + C: T.Buffer((T.int64(2), T.int64(3)), "bool"), + ): + T.evaluate(0) + + @R.function + def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([6]), virtual_device_index=0, storage_scope="global", dtype="bool" + ) + alloc: R.Tensor((2, 3), dtype="bool") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="bool" + ) + _2: R.Tuple() = cls.add1(y, y, alloc) + _3: 
R.Tuple() = R.memory.kill_tensor(alloc) + gv12: R.Tensor((2, 3), dtype="bool") = alloc + _4: R.Tuple() = R.memory.kill_storage(storage) + return y + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_same_dtype(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.add(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="float32") = alloc1 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = cls.add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = cls.add(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="float32") = alloc1 + _4: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_if_cond(): + @tvm.script.ir_module + class Module: + @T.prim_func + def all_less_than_zero(A: T.Buffer((2, 3), "float32"), B: T.Buffer((), "bool")): + T.evaluate(0) + + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((), dtype="bool") = R.builtin.alloc_tensor( + R.shape([]), dtype="bool", runtime_device_index=0 + ) + _: R.Tuple() = cls.all_less_than_zero(x, alloc) + x1: R.Tensor((), dtype="bool") = alloc + if x1: + y: R.Tensor((2, 3), dtype="float32") = x + else: + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.exp(x, alloc1) + gv3: R.Tensor((2, 3), dtype="float32") = alloc1 + y: R.Tensor((2, 3), dtype="float32") = gv3 + return x + + # The pass does no change. 
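+    # (The allocated tensor feeds the branch condition and another allocation sits inside
+    # the else branch; the planner treats control flow conservatively, so the
+    # R.builtin.alloc_tensor calls are kept as-is.)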
+ mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_if_then_else(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main( + cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.exp(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + if cond: + z: R.Tensor((2, 3), dtype="float32") = y + else: + z: R.Tensor((2, 3), dtype="float32") = y + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_cross_block_use(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main( + cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.exp(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + if cond: + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.exp(y, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + z: R.Tensor((2, 3), dtype="float32") = y2 + else: + alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _2: R.Tuple() = cls.exp(y, alloc2) + y2: R.Tensor((2, 3), dtype="float32") = alloc2 + z: R.Tensor((2, 3), dtype="float32") = y2 + return x + + # The pass does no change. 
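+    # (Here the tensor bound to `y` is consumed inside both branches, so its lifetime
+    # spans the If and the planner again leaves the module unchanged.)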
+ mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_nested_tuple(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + cls = Module + _: R.Tuple() = cls.exp(x, alloc) + y1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.exp(x, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _2: R.Tuple() = cls.exp(x, alloc2) + y3: R.Tensor((2, 3), dtype="float32") = alloc2 + t: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = ( + y1, + y2, + ) + nt: R.Tuple( + R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")), + R.Tensor((2, 3), dtype="float32"), + ) = (t, y3) + nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[ + 0 + ] + y1_: R.Tensor((2, 3), dtype="float32") = nt0[0] + y2_: R.Tensor((2, 3), dtype="float32") = nt0[1] + y3_: R.Tensor((2, 3), dtype="float32") = nt[1] + alloc3: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _3: R.Tuple() = cls.exp(y1_, alloc3) + z1: R.Tensor((2, 3), dtype="float32") = alloc3 + alloc4: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _4: R.Tuple() = cls.exp(y2_, alloc4) + z2: R.Tensor((2, 3), dtype="float32") = alloc4 + alloc5: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _5: R.Tuple() = cls.exp(y3_, alloc5) + z3: R.Tensor((2, 3), dtype="float32") = alloc5 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = cls.exp(x, alloc) + y1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="float32" + ) + _1: R.Tuple() = cls.exp(x, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + storage2: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc2: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage2, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = cls.exp(x, alloc2) + y3: R.Tensor((2, 3), dtype="float32") = alloc2 + t: R.Tuple(R.Tensor((2, 3), dtype="float32"), 
R.Tensor((2, 3), dtype="float32")) = ( + y1, + y2, + ) + nt: R.Tuple( + R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")), + R.Tensor((2, 3), dtype="float32"), + ) = (t, y3) + nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[ + 0 + ] + y1_: R.Tensor((2, 3), dtype="float32") = nt0[0] + y2_: R.Tensor((2, 3), dtype="float32") = nt0[1] + y3_: R.Tensor((2, 3), dtype="float32") = nt[1] + storage3: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc3: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage3, 0, R.shape([2, 3]), dtype="float32" + ) + _3: R.Tuple() = cls.exp(y1_, alloc3) + _4: R.Tuple() = R.memory.kill_tensor(alloc) + _11: R.Tuple() = R.memory.kill_tensor(alloc3) + z1: R.Tensor((2, 3), dtype="float32") = alloc3 + alloc4: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _41: R.Tuple() = cls.exp(y2_, alloc4) + _21: R.Tuple() = R.memory.kill_tensor(alloc1) + _31: R.Tuple() = R.memory.kill_tensor(alloc4) + z2: R.Tensor((2, 3), dtype="float32") = alloc4 + alloc5: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage3, 0, R.shape([2, 3]), dtype="float32" + ) + _5: R.Tuple() = cls.exp(y3_, alloc5) + _42: R.Tuple() = R.memory.kill_tensor(alloc2) + _51: R.Tuple() = R.memory.kill_tensor(alloc5) + z3: R.Tensor((2, 3), dtype="float32") = alloc5 + _9: R.Tuple() = R.memory.kill_storage(storage) + _7: R.Tuple() = R.memory.kill_storage(storage1) + _8: R.Tuple() = R.memory.kill_storage(storage2) + _6: R.Tuple() = R.memory.kill_storage(storage3) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_call_func_other_than_primfunc(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _ = R.add(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + return x + + # The pass does no change. 
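+    # (Presumably because the allocation is consumed by a high-level op (`R.add`) rather
+    # than a call into a TIR PrimFunc, so the planner cannot assume the tensor is not
+    # retained by the callee.)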
+ mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_call_packed_external_func(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()]) + y: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()]) + z: R.Tensor((2, 3), dtype="float32") = alloc1 + return z + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32") + ) + _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),)) + y: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), R.dtype("float32"), R.prim_value(0) + ) + _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),)) + _2: R.Tuple = R.memory.kill_tensor(alloc) + z: R.Tensor((2, 3), dtype="float32") = alloc1 + _3: R.Tuple = R.memory.kill_storage(storage) + return z + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_symbolic_shape(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(var_A: T.handle, var_B: T.handle): + m = T.int64() + n = T.int64() + A = T.match_buffer(var_A, (m, n), "float32") + B = T.match_buffer(var_B, (m, n), "float32") + T.evaluate(0) + + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m = T.int64() + n = T.int64() + alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor( + R.shape([m, n]), dtype="float32", runtime_device_index=0 + ) + _ = Module.exp(x, alloc) + y: R.Tensor((m, n), dtype="float32") = alloc + return x + + # The pass does no change. 
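+    # (The allocation size depends on the symbolic shape (m, n), so a static storage size
+    # cannot be computed.)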
+ mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_zero_reference(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + return x + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_param(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + B: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + C: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 50), dtype="float32"), y: R.Tensor((100,), dtype="float32") + ) -> R.Tensor((2, 25, 2), dtype="float32"): + lv: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(x, (2, 25, 2)) + lv1: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(y, (2, 25, 2)) + alloc: R.Tensor((2, 25, 2), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 25, 2]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = Module.add(lv, lv1, alloc) + gv: R.Tensor((2, 25, 2), dtype="float32") = alloc + return gv + + # The pass does no change. 
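+    # (Likely because the only allocation is the function's return value, so it must
+    # outlive the function body and cannot be pooled or killed here.)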
+ mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_multiple_functions(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def func1( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="int32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.add1(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="int32") = alloc1 + return x + + @R.function + def func2( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Module + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = cls.add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = cls.add(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="float32") = alloc1 + return x + + @I.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def func1( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = cls.add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32" + ) + alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="int32" + ) + _2: R.Tuple() = cls.add1(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="int32") = alloc1 + _5: R.Tuple() = R.memory.kill_storage(storage) + _4: R.Tuple() = R.memory.kill_storage(storage1) + return x + + @R.function + def func2( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, 
storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = cls.add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = cls.add(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="float32") = alloc1 + _4: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py new file mode 100644 index 000000000000..6b699b5165c3 --- /dev/null +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -0,0 +1,845 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import numpy as np +import tvm +from tvm import relax +import tvm.testing +from tvm.relax.transform import ToMixedPrecision +from tvm.script.parser import ir as I, relax as R + + +def _assert_test(input, expected, expected2): + mod = ToMixedPrecision()(input) + tvm.ir.assert_structural_equal(mod, expected) + mod = ToMixedPrecision(out_dtype="float16")(input) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected2) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32") + R.output(gv) + return gv + + _assert_test(Input, Expected, Expected2) + + +def test_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(lv) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv_1: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2, dtype="float16") + lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(lv_1) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 
28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv_1: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(lv_1) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32") + R.output(gv) + return gv + + _assert_test(Input, Expected, Expected2) + + +def test_relu_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x0, dtype="float16") + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv1, + lv, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2, dtype="float16") + lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, dtype="float32") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x0, dtype="float16") + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv1, + lv, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32") + R.output(gv2) + return gv2 + + _assert_test(Input, Expected, Expected2) + + +def test_conv2d_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 24, 24), "float32") = 
R.nn.conv2d(gv2, w2, out_dtype="float32") + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + w2: R.Tensor((4, 4, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2, dtype="float16") + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3, dtype="float16") + gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d( + gv2, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + w2: R.Tensor((4, 4, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2, dtype="float16") + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv) + lv3: R.Tensor((2, 4, 24, 24), dtype="float16") = R.nn.conv2d( + gv2, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.astype(lv3, dtype="float32") + R.output(gv3) + return gv3 + + _assert_test(Input, Expected, Expected2) + + +def test_gemm_add_silu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 320), "float32"), + w1: R.Tensor((320, 1280), "float32"), + w2: R.Tensor((2, 1280), "float32"), + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv0: R.Tensor((2, 1280), "float32") = R.matmul(x, w1, out_dtype="float32") + gv1: R.Tensor((2, 1280), "float32") = R.add(gv0, w2) + gv2: R.Tensor((2, 1280), "float32") = R.nn.silu(gv1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 320), dtype="float32"), + w1: R.Tensor((320, 1280), dtype="float32"), + w2: R.Tensor((2, 1280), dtype="float32"), + ) -> R.Tensor((2, 1280), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 320), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((320, 1280), dtype="float16") = R.astype(w1, dtype="float16") + lv2: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32") + gv0: R.Tensor((2, 1280), dtype="float16") = R.astype(lv2, dtype="float16") + lv3: R.Tensor((2, 1280), 
dtype="float32") = R.astype(gv0, dtype="float32") + gv1: R.Tensor((2, 1280), dtype="float32") = R.add(lv3, w2) + gv2: R.Tensor((2, 1280), dtype="float32") = R.nn.silu(gv1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 320), dtype="float32"), + w1: R.Tensor((320, 1280), dtype="float32"), + w2: R.Tensor((2, 1280), dtype="float32"), + ) -> R.Tensor((2, 1280), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 320), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((320, 1280), dtype="float16") = R.astype(w1, dtype="float16") + gv0: R.Tensor((2, 1280), dtype="float16") = R.matmul(lv, lv1, out_dtype="float16") + lv2: R.Tensor((2, 1280), dtype="float32") = R.astype(gv0, dtype="float32") + gv1: R.Tensor((2, 1280), dtype="float32") = R.add(lv2, w2) + gv2: R.Tensor((2, 1280), dtype="float32") = R.nn.silu(gv1) + R.output(gv2) + return gv2 + + _assert_test(Input, Expected, Expected2) + + +def test_tuple(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + w_2: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv3 = (gv, gv2) + gv4 = (gv3, gv2) + gv5 = gv4[0] + gv6 = gv5[0] + gv7 = R.nn.conv2d(gv6, w_2, out_dtype="float32") + R.output(gv7) + return gv7 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + w_2: R.Tensor((4, 4, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w_2, dtype="float16") + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3, dtype="float16") + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv4, dtype="float16") + gv3: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = (gv, gv2) + gv4: R.Tuple( + R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = (gv3, gv2) + gv5: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = gv4[0] + gv6: R.Tensor((2, 4, 26, 26), dtype="float16") = gv5[0] + gv7: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d( + gv6, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 
28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + w_2: R.Tensor((4, 4, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w_2, dtype="float16") + gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv3: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = (gv, gv2) + gv4: R.Tuple( + R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = (gv3, gv2) + gv5: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float16"), + R.Tensor((2, 4, 26, 26), dtype="float16"), + ) = gv4[0] + gv6: R.Tensor((2, 4, 26, 26), dtype="float16") = gv5[0] + lv3: R.Tensor((2, 4, 24, 24), dtype="float16") = R.nn.conv2d( + gv6, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv7: R.Tensor((2, 4, 24, 24), dtype="float32") = R.astype(lv3, dtype="float32") + R.output(gv7) + return gv7 + + _assert_test(Input, Expected, Expected2) + + +def test_concat_matmul(): + @I.ir_module + class Input: + @R.function + def main( + lv10: R.Tensor((2, 160), "float32"), + lv12: R.Tensor((2, 160), "float32"), + w: R.Tensor((320, 1280), "float32"), + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv13: R.Tensor((2, 320), "float32") = R.concat((lv10, lv12), axis=-1) + lv14: R.Tensor((2, 1280), "float32") = R.matmul(lv13, w, out_dtype="float32") + R.output(lv14) + return lv14 + + @I.ir_module + class Expected: + @R.function + def main( + lv10: R.Tensor((2, 160), dtype="float32"), + lv12: R.Tensor((2, 160), dtype="float32"), + w: R.Tensor((320, 1280), dtype="float32"), + ) -> R.Tensor((2, 1280), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((320, 1280), dtype="float16") = R.astype(w, dtype="float16") + lv13: R.Tensor((2, 320), dtype="float32") = R.concat((lv10, lv12), axis=-1) + lv1: R.Tensor((2, 320), dtype="float16") = R.astype(lv13, dtype="float16") + lv14: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32") + R.output(lv14) + return lv14 + + @I.ir_module + class Expected2: + @R.function + def main( + lv10: R.Tensor((2, 160), dtype="float32"), + lv12: R.Tensor((2, 160), dtype="float32"), + w: R.Tensor((320, 1280), dtype="float32"), + ) -> R.Tensor((2, 1280), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((320, 1280), dtype="float16") = R.astype(w, dtype="float16") + lv13: R.Tensor((2, 320), dtype="float32") = R.concat((lv10, lv12), axis=-1) + lv1: R.Tensor((2, 320), dtype="float16") = R.astype(lv13, dtype="float16") + lv2: R.Tensor((2, 1280), dtype="float16") = R.matmul(lv1, lv, out_dtype="float16") + lv14: R.Tensor((2, 1280), dtype="float32") = R.astype(lv2, dtype="float32") + R.output(lv14) + return lv14 + + 
_assert_test(Input, Expected, Expected2) + + +def test_conv2d_softmax(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1)) + gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1) + gv2 = R.add(gv, gv1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 3, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + lv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.conv2d( + lv1, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(lv2, dtype="float16") + gv1: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.softmax(x, axis=1) + lv3: R.Tensor((2, 3, 28, 28), dtype="float32") = R.astype(gv, dtype="float32") + gv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.add(lv3, gv1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3, 3, 3), dtype="float32") + ) -> R.Tensor((2, 3, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16") + lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.nn.conv2d( + lv1, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + gv1: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.softmax(x, axis=1) + lv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.astype(gv, dtype="float32") + gv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.add(lv2, gv1) + R.output(gv2) + return gv2 + + _assert_test(Input, Expected, Expected2) + + +def test_conv2d_bias_conv2d(): + @tvm.script.ir_module + class Input: + @R.function + def main( + z: R.Tensor((1, 4, 64, 64), dtype="float32"), + w0: R.Tensor((512, 4, 3, 3), dtype="float16"), + w1: R.Tensor((512,), dtype="float16"), + w2: R.Tensor((4, 4, 1, 1), dtype="float16"), + w3: R.Tensor((4,), dtype="float16"), + ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((512, 4, 3, 3), dtype="float32") = R.wrap_param(w0, dtype="float32") + lv1: R.Tensor((512,), dtype="float32") = R.wrap_param(w1, dtype="float32") + lv140: R.Tensor((4, 4, 1, 1), dtype="float32") = R.wrap_param(w2, dtype="float32") + lv141: R.Tensor((4,), dtype="float32") = R.wrap_param(w3, dtype="float32") + lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + z, + lv140, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(lv141, (1, 4, 1, 1)) + lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143) + lv145: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d( + 
lv144, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv146: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(lv1, (1, 512, 1, 1)) + lv147: R.Tensor((1, 512, 64, 64), dtype="float32") = R.add(lv145, lv146) + gv: R.Tensor((1, 512, 64, 64), dtype="float32") = lv147 + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + z: R.Tensor((1, 4, 64, 64), dtype="float32"), + w0: R.Tensor((512, 4, 3, 3), dtype="float16"), + w1: R.Tensor((512,), dtype="float16"), + w2: R.Tensor((4, 4, 1, 1), dtype="float16"), + w3: R.Tensor((4,), dtype="float16"), + ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(z, dtype="float16") + lv_1: R.Tensor((512, 4, 3, 3), dtype="float16") = w0 + lv1: R.Tensor((512,), dtype="float16") = w1 + lv140: R.Tensor((4, 4, 1, 1), dtype="float16") = w2 + lv141: R.Tensor((4,), dtype="float16") = w3 + lv1_1: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + lv, + lv140, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv142: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(lv1_1, dtype="float16") + lv143: R.Tensor((1, 4, 1, 1), dtype="float16") = R.reshape(lv141, (1, 4, 1, 1)) + lv144: R.Tensor((1, 4, 64, 64), dtype="float16") = R.add(lv142, lv143) + lv2: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d( + lv144, + lv_1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv145: R.Tensor((1, 512, 64, 64), dtype="float16") = R.astype(lv2, dtype="float16") + lv146: R.Tensor((1, 512, 1, 1), dtype="float16") = R.reshape(lv1, (1, 512, 1, 1)) + lv147: R.Tensor((1, 512, 64, 64), dtype="float16") = R.add(lv145, lv146) + gv: R.Tensor((1, 512, 64, 64), dtype="float32") = R.astype(lv147, dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + z: R.Tensor((1, 4, 64, 64), dtype="float32"), + w0: R.Tensor((512, 4, 3, 3), dtype="float16"), + w1: R.Tensor((512,), dtype="float16"), + w2: R.Tensor((4, 4, 1, 1), dtype="float16"), + w3: R.Tensor((4,), dtype="float16"), + ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(z, dtype="float16") + lv_1: R.Tensor((512, 4, 3, 3), dtype="float16") = w0 + lv1: R.Tensor((512,), dtype="float16") = w1 + lv140: R.Tensor((4, 4, 1, 1), dtype="float16") = w2 + lv141: R.Tensor((4,), dtype="float16") = w3 + lv142: R.Tensor((1, 4, 64, 64), dtype="float16") = R.nn.conv2d( + lv, + lv140, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + lv143: R.Tensor((1, 4, 1, 1), dtype="float16") = R.reshape( + lv141, R.shape([1, 4, 1, 1]) + ) + lv144: R.Tensor((1, 4, 64, 64), dtype="float16") = R.add(lv142, lv143) + lv145: R.Tensor((1, 512, 64, 64), dtype="float16") = R.nn.conv2d( + lv144, + lv_1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float16", + ) + lv146: R.Tensor((1, 512, 1, 1), dtype="float16") = R.reshape( + lv1, 
R.shape([1, 512, 1, 1]) + ) + lv147: R.Tensor((1, 512, 64, 64), dtype="float16") = R.add(lv145, lv146) + gv: R.Tensor((1, 512, 64, 64), dtype="float32") = R.astype(lv147, dtype="float32") + R.output(gv) + return gv + + binding = { + "w0": np.random.uniform(size=(512, 4, 3, 3)).astype("float16"), + "w1": np.random.uniform(size=(512,)).astype("float16"), + "w2": np.random.uniform(size=(4, 4, 1, 1)).astype("float16"), + "w3": np.random.uniform(size=(4,)).astype("float16"), + } + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + Input = relax.transform.BindParams("main", binding)(Input) + Expected = relax.transform.BindParams("main", binding)(Expected) + Expected2 = relax.transform.BindParams("main", binding)(Expected2) + _assert_test(Input, Expected, Expected2) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py new file mode 100644 index 000000000000..5c2f165dc31d --- /dev/null +++ b/tests/python/relax/test_tuning_api.py @@ -0,0 +1,782 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import numpy as np +import os.path as osp +import tempfile +from typing import List +from math import isclose + +import tvm +from tvm import ir +from tvm.ir import transform +from tvm.ir.transform import PassContext +from tvm.ir.module import IRModule +from tvm.script import tir as T, relax as R +from tvm import relax +from tvm.relax.expr import Expr, DataflowBlock, Function +from tvm.relax.transform.tuning_api import ( + Choice, + Knob, + Trace, + TuningRecord, + JSONDatabase, + default_generate_candidate, + default_consider_eval_passes, + default_evaluate, + select_best_candidate, + get_trace, +) + + +@tvm.script.ir_module +class TestModule: + @T.prim_func + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + # Input IRModule. + @R.function + def before(c0: R.Tensor((16, 16), "int32")): + cls = TestModule + lv0 = R.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="int32")) + return lv0 + + # Expected IRModule after transformation. + @R.function + def expected(c1: R.Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + +def gen_mod(mod, name, binding): + funcs = {} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + + for k, v in mod.functions.items(): + if isinstance(v, tvm.relax.Function): + if k.name_hint == name: + # rename to main. 
+ gv = tvm.ir.GlobalVar("main") + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( + "global_symbol", "main" + ) + else: + funcs[k] = v + mod = tvm.IRModule(funcs) + return relax.transform.BindParams("main", binding)(mod) + + +# Setup for simple testing with IRModule. +def setup_test(): + mod = TestModule + assert isinstance(mod, tvm.IRModule) + return gen_mod(mod, "before", {}) + + +# Setup for testing with constant folding. +def setup_test_const_folding(): + mod = TestModule + assert isinstance(mod, tvm.IRModule) + # Test setup. + c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(mod, "before", {"c0": c0_np}) + expected = gen_mod(mod, "expected", {"c1": c1_np}) + + return before, expected + + +# Define a choice by using FoldConstant pass. +@tvm.register_func("testing.apply_fold_constant") +def apply_fold_constant(mod): + return relax.transform.FoldConstant()(mod) + + +@tvm.register_func("testing.add_global_symbol") +def add_global_symbol(mod, func_name, global_symbol): + mod[func_name] = mod[func_name].with_attr("global_symbol", global_symbol) + return mod + + +@tvm.register_func("testing.check_num_functions") +def check_num_funcs(mod, N): + # Explicit type specification is necessary. + # Otherwise, PackedFunc cannot derive the return type correctly. + # e.g., Check failed: type_code_ == kDLInt (8 vs. 0) : expected int but got Object + return bool(len(mod.functions) == N) + + +def test_choice(): + # Test setup. + ( + before, + expected, + ) = setup_test_const_folding() + + # Without any argument, default setting will be used for both transformation and constraint functions. + # default transformation function will return the original IRModule without any change. + choice = Choice( + # - transform_func_key="relax.tuning_api.Choice.default_transform_func" + # - constr_func_key="relax.tuning_api.Choice.default_constr_func") + ) + # Load transformation function from the choice and apply it. + after = choice.apply_transform_func(before) + tvm.ir.assert_structural_equal(after, before) + + choice = Choice("testing.apply_fold_constant") + # Load transformation function from the choice and apply it. + after = choice.apply_transform_func(before) + tvm.ir.assert_structural_equal(after, expected) + + # Create a choice that tags global symbol onto target function. + choice = Choice("testing.add_global_symbol", ["addone", "test-symbol"]) + after = choice.apply_transform_func(before) + assert after["addone"].attrs["global_symbol"] == "test-symbol" + # The transformation should be applied with Copy-On-Write. + # So, the original module should be unchanged. 
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+    # Test a choice with an impossible constraint.
+    choice = Choice(
+        transform_func_key="testing.add_global_symbol",
+        transform_func_args=["addone", "test-symbol"],
+        constr_func_key="testing.check_num_functions",
+        constr_func_args=[1000],
+    )
+    # Since the constraint is not met, it should return the original IRModule.
+    after = choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "addone"
+
+    # Test a choice with a satisfiable constraint.
+    choice = Choice(
+        transform_func_key="testing.add_global_symbol",
+        transform_func_args=["addone", "test-symbol"],
+        constr_func_key="testing.check_num_functions",
+        constr_func_args=[2],
+    )
+    # Since the constraint is satisfied, the transformation should be applied.
+    after = choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "test-symbol"
+    # The original module should be unchanged.
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+    # Test roundtrip.
+    # Export as JSON.
+    json_obj = choice.as_json()
+    # Import JSON.
+    new_choice = Choice.from_json(json_obj)
+    # Test the imported choice.
+    after = new_choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "test-symbol"
+    # The original module should be unchanged.
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+
+def test_knob():
+    # Test setup.
+    before, expected = setup_test_const_folding()
+
+    # Users can define a set of choices with a list.
+    choices = [
+        Choice("testing.apply_fold_constant"),
+        Choice(),
+    ]
+
+    # Define knob.
+    knob = Knob("TestKnob", choices)
+    # Check the sanity of the decision space.
+    assert knob.verify(0)
+    assert knob.verify(1)
+    assert not knob.verify(3)
+
+    # Check the sanity of each decision.
+    after_apply = knob.apply(before, 0)
+    after_noapply = knob.apply(before, 1)
+
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+
+    # Users can define a set of choices with a dict.
+    choices = {
+        "apply": Choice("testing.apply_fold_constant"),
+        "noapply": Choice(),
+        "apply_with_impossible_constr": Choice(
+            transform_func_key="testing.apply_fold_constant",
+            constr_func_key="testing.check_num_functions",
+            constr_func_args=[1000],
+        ),
+    }
+    # Define knob.
+    knob = Knob("TestKnob", choices)
+    assert knob.verify("apply")
+    assert knob.verify("noapply")
+    assert knob.verify("apply_with_impossible_constr")
+    assert not knob.verify("INVALID")
+
+    after_apply = knob.apply(before, "apply")
+    after_noapply = knob.apply(before, "noapply")
+    # Because the constraint was not satisfied, it will return the original IRModule.
+    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+    tvm.ir.assert_structural_equal(after_apply_with_constr, before)
+
+    # Test roundtrip.
+    # Export as JSON.
+    json_obj = knob.as_json()
+    # Import JSON.
+    new_knob = Knob.from_json(json_obj)
+    assert new_knob.name == knob.name
+    # Test the imported knob.
+    assert new_knob.verify("apply")
+    assert new_knob.verify("noapply")
+    assert new_knob.verify("apply_with_impossible_constr")
+    assert not new_knob.verify("INVALID")
+
+    after_apply = new_knob.apply(before, "apply")
+    after_noapply = new_knob.apply(before, "noapply")
+    # Because the constraint was not satisfied, it will return the original IRModule.
+    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+    tvm.ir.assert_structural_equal(after_apply_with_constr, before)
+
+
+def test_trace():
+    before, expected = setup_test_const_folding()
+
+    # Define choices and their knob.
+    choices = {
+        "apply": Choice(
+            transform_func_key="testing.apply_fold_constant",
+            transform_func_args=[],
+            constr_func_key="testing.check_num_functions",
+            constr_func_args=[2],
+        ),
+        "noapply": Choice(),
+    }
+    knob = Knob("TestKnob", choices)
+
+    # Define a Trace with an empty decision (transformation) history.
+    trace = Trace(before)
+    assert trace.size == 0
+
+    # Define a Trace with a single decision (transformation) history.
+    trace = Trace(before, [knob], ["noapply"])
+    assert trace.size == 1
+    tvm.ir.assert_structural_equal(trace.in_mod, before)
+    tvm.ir.assert_structural_equal(trace.out_mod, before)
+
+    # Add a new knob and its decision to the trace.
+    # It will update the current trace and return its new output IRModule.
+    out: IRModule = trace.add(knob, "noapply")
+    assert trace.size == 2
+    tvm.ir.assert_structural_equal(trace.in_mod, before)
+    tvm.ir.assert_structural_equal(trace.out_mod, before)
+    tvm.ir.assert_structural_equal(out, before)
+    # Assume we assign an arbitrary performance number.
+    trace.set_perf(100)
+    assert trace.perf == 100
+
+    # Add a new knob and its decision to the trace.
+    out: IRModule = trace.add(knob, "apply")
+    tvm.ir.assert_structural_equal(trace.in_mod, before)
+    tvm.ir.assert_structural_equal(trace.out_mod, expected)
+    tvm.ir.assert_structural_equal(out, expected)
+
+    assert trace.size == 3
+    # Perf should be reset to -1 when a new knob is applied.
+    assert trace.perf == -1
+
+    # Test roundtrip.
+    # Export as JSON.
+    json_obj = trace.as_json()
+    # Import JSON.
+    new_trace = Trace.from_json(json_obj)
+    tvm.ir.assert_structural_equal(trace.in_mod, new_trace.in_mod)
+    assert str(trace) == str(new_trace)
+    assert new_trace.size == 3
+    tvm.ir.assert_structural_equal(trace.out_mod, new_trace.out_mod)
+
+
+def test_trace_wrapper():
+    mod = setup_test()
+    assert isinstance(mod, tvm.IRModule)
+    assert isinstance(Trace(mod), Trace)
+    assert isinstance(get_trace(mod), Trace)
+    assert isinstance(get_trace(mod["main"]), Trace)
+    assert isinstance(get_trace(mod["addone"]), Trace)
+
+
+def create_tmp_database(tmpdir: str) -> JSONDatabase:
+    path_workload = osp.join(tmpdir, "workloads.json")
+    path_tuning_record = osp.join(tmpdir, "tuning_records.json")
+    path_measurement_record = osp.join(tmpdir, "measurement_records.json")
+    return JSONDatabase(path_workload, path_tuning_record, path_measurement_record)
+
+
+def test_database():
+    def equal_measurement_record(a: List[float], b: List[float]):
+        assert len(a) == len(b)
+        for i in range(len(a)):
+            assert isclose(a[i], b[i], rel_tol=1e-5)
+
+    def equal_tuning_record(a: TuningRecord, b: TuningRecord):
+        assert str(a.trace) == str(b.trace)
+        equal_measurement_record(a.run_secs, b.run_secs)
+
+    # Test setup.
+    (
+        mod1,
+        mod2,
+    ) = setup_test_const_folding()
+    knob = Knob("test", {"noapply": Choice()})
+    trace = Trace(mod1, [knob, knob], ["noapply", "noapply"])
+    target = tvm.target.Target("llvm")
+
+    # Test roundtrip.
+    run_secs = [1.0, 0.9, 0.4]
+    tuning_record = TuningRecord(
+        trace,
+        run_secs,
+    )
+    new_tuning_record = TuningRecord.from_json(json_obj=tuning_record.as_json())
+    equal_tuning_record(tuning_record, new_tuning_record)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        database = create_tmp_database(tmpdir)
+        workload1 = database.commit_workload(mod1)
+
+        database.commit_measurement_record(workload1, target, run_secs)
+        new_run_secs1 = database.get_measurement_record(workload1, target)
+        equal_measurement_record(run_secs, new_run_secs1)
+        workload2 = database.commit_workload(mod2)
+        new_run_secs2 = database.get_measurement_record(workload2, target)
+        assert len(new_run_secs2) == 0
+
+        database.commit_tuning_record(workload1, target, tuning_record)
+        new_tuning_records = database.get_top_k(workload1, target, top_k=1)
+        assert len(new_tuning_records) == 1
+        equal_tuning_record(tuning_record, new_tuning_records[0])
+        new_tuning_records = database.get_top_k(workload1, target, top_k=0)
+        assert len(new_tuning_records) == 0
+
+
+def test_default_functions():
+    mod = setup_test()
+    assert isinstance(mod, tvm.IRModule)
+
+    # Define choice, knob, trace.
+    choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
+    knob = Knob("TestKnob", choices)
+    trace = Trace(mod)
+
+    # Launch a pass pipeline in trace mode.
+    with tempfile.TemporaryDirectory() as tmpdir:
+        database = create_tmp_database(tmpdir)
+        with transform.PassContext(trace=trace, tuning_api_database=database):
+            # The default generation function expands every valid choice.
+            candidates = default_generate_candidate([knob], trace)
+            assert len(candidates) == 2
+
+            # The default evaluate function uses the MetaSchedule builder/runner.
+            # Since a builder/runner is not provided, the local builder/runner will be used.
+            default_evaluate(candidates, "llvm --num-cores=16")
+            assert PassContext.current().num_evals == 2
+
+            # Because these candidates are already evaluated, num_evals stays the same.
+            default_evaluate(candidates, "llvm --num-cores=16")
+            assert PassContext.current().num_evals == 2
+
+            # Test with multiple knobs.
+            candidates = default_generate_candidate([knob, knob], trace)
+            assert len(candidates) == 4
+
+        # Launch a new pass pipeline in trace mode.
+        with transform.PassContext(trace=trace, tuning_api_database=database):
+            candidates = default_generate_candidate([knob], trace)
+            assert len(candidates) == 2
+            # Provide a tuning pass as an eval pass.
+            # Note that MockConstFoldingTuningPass() has its own generation and evaluation functions.
+            # Evaluation is done in a tournament fashion.
+            # `default_consider_eval_passes` converts candidates into their best versions by considering eval_passes.
+            # For example, if candidates = [C1, C2],
+            # `default_consider_eval_passes` will return the best form of C1 (C11 vs. C12) and of C2 (C21 vs. C22)
+            # that can be generated by eval_passes.
+            # Assuming C11 > C12 and C21 < C22,
+            # new_candidates = [C11, C22]
+            new_candidates = default_consider_eval_passes(
+                candidates, [MockConstFoldingTuningPass(eval_passes=[])]
+            )
+
+            # len(candidates) == len(new_candidates).
+            assert len(new_candidates) == 2
+            # To find the best version of each candidate, it takes 4 evals (C11, C12, C21, C22).
+            assert PassContext.current().num_evals == 4
+
+        HeuristicPass = relax.transform.FoldConstant
+        with transform.PassContext(trace=trace, tuning_api_database=database):
+            candidates = default_generate_candidate([knob], trace)
+            assert len(candidates) == 2
+            # Provide a heuristic pass as an eval pass.
+            new_candidates = default_consider_eval_passes(candidates, [HeuristicPass()])
+            # Since a heuristic pass has a single decision, it does not need any tournament.
+            # new_candidates = [C11, C21]
+            assert len(new_candidates) == 2
+            # We only conduct evaluation when it is necessary (e.g., choosing the better candidate in a tuning pass).
+            # A heuristic pass won't conduct any evaluation.
+            assert PassContext.current().num_evals == 0
+
+
+# TODO(sunggg): Do we need to serialize pass context as well?
+def test_pass_context():
+    before, expected = setup_test_const_folding()
+    HeuristicPass = relax.transform.FoldConstant
+    # FoldConstant implicitly performs TIR passes (probably for constant evaluation).
+    # If make_traceable is not provided, the pass infra will make every non-traceable pass traceable by default.
+    seq = transform.Sequential([HeuristicPass()])
+    with transform.PassContext(
+        trace=Trace(before),
+    ):
+        after = seq(before)
+        tvm.ir.assert_structural_equal(after, expected)
+        assert PassContext.current().get_trace_stack_size() == 1
+        # The exact number of implicit passes might change as TVM develops more passes.
+        # As of today, this size returns 57.
+        assert PassContext.current().get_current_trace().size > 1
+
+    # We can explicitly specify which pass we want to keep track of.
+    with transform.PassContext(trace=Trace(before), make_traceable=["FoldConstant"]):
+        after = seq(before)
+        tvm.ir.assert_structural_equal(after, expected)
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 1
+
+    # Check the functionality of the trace stack.
+    with transform.PassContext(trace=Trace(before)):
+        assert PassContext.current().get_trace_stack_size() == 1
+        PassContext.current().push_trace(Trace(before))
+        assert PassContext.current().get_trace_stack_size() == 2
+        PassContext.current().pop_trace()
+        assert PassContext.current().get_trace_stack_size() == 1
+        PassContext.current().pop_trace()
+        assert PassContext.current().get_trace_stack_size() == 0
+
+
+# Mock evaluation pass for testing.
+# Assigns an arbitrary performance number to each candidate.
+def mock_evaluate(candidates: List[Trace], target_str: str, ctx: PassContext):
+    num_evals = 0
+    # Evaluation
+    for candidate in candidates:
+        # If this candidate is already evaluated, skip the measurement.
+        if candidate.perf != -1:
+            continue
+
+        num_evals += 1
+        # Assign an arbitrary performance number.
+        mock_perf = 100 - (ctx.num_evals + num_evals)
+        candidate.set_perf(mock_perf)
+    # Update the number of evals for testing.
+    ctx.inc_num_evals(num_evals)
+
+
+# Mock tuning pass that determines whether to apply relax.transform.FoldConstant().
+# Each pass invocation will generate two candidates for the incoming IRModule.
+# In the Relax pass infra, each pass defines its own way of generating and evaluating candidates,
+# without needing to know how other passes generate and evaluate theirs.
+# This significantly eases development, since accounting for the interaction with
+# (potentially hundreds of) other passes is known to be a hard problem.
+@ir.transform.module_pass(opt_level=0, traceable=True)
+class MockConstFoldingTuningPass(transform.Pass):
+    def __init__(
+        self,
+        f_generate_candidate=None,
+        f_evaluate=mock_evaluate,
+        eval_passes: List[transform.Pass] = None,
+        required: List[transform.Pass] = [],
+    ):
+        self.f_generate_candidate = (
+            f_generate_candidate if f_generate_candidate else default_generate_candidate
+        )
+        self.f_evaluate = f_evaluate if f_evaluate else default_evaluate
+        self.eval_passes = eval_passes
+        self.required = required
+
+    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
+        trace = ctx.pop_trace()
+
+        # Create mock choices for testing.
+        choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
+        # A tuning pass manages a set of transformation functions registered via a knob.
+        knob = Knob("MockTuningKnob", choices)
+
+        candidates = self.f_generate_candidate([knob], trace, self.eval_passes)
+        self.f_evaluate(candidates, "llvm", ctx)
+        best_trace = select_best_candidate(candidates)
+
+        ctx.push_trace(best_trace)
+        return best_trace.out_mod
+
+
+def test_module_pass():
+    mod = setup_test()
+    assert isinstance(mod, tvm.IRModule)
+    # Test setup.
+    c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+    mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
+    HeuristicPass = relax.transform.FoldConstant
+
+    # Tuning pass without any eval_pass.
+    mock_pass = MockConstFoldingTuningPass(eval_passes=[])
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 1
+
+    # A heuristic pass should not affect the number of candidates.
+    mock_pass = MockConstFoldingTuningPass(eval_passes=[HeuristicPass()])
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 2
+
+    # Joint optimization increases the search space in a combinatorial way.
+    mock_pass = MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2 * 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 2
+
+    # Joint optimization can be nested.
+    mock_pass = MockConstFoldingTuningPass(
+        eval_passes=[
+            MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
+        ]
+    )
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2 * 2 * 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 3
+
+    # Tuning passes and heuristic passes can be used together.
+    # Note that a heuristic pass does not increase the search space (num_evals).
+    # It only increases the length of the trace.
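+    # In the configuration below, only the three tuning passes multiply the search space
+    # (2 * 2 * 2 evals); the three heuristic passes just add one trace entry each (6 in total).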
+    mock_pass = MockConstFoldingTuningPass(
+        eval_passes=[
+            HeuristicPass(),
+            MockConstFoldingTuningPass(
+                eval_passes=[
+                    MockConstFoldingTuningPass(eval_passes=[HeuristicPass(), HeuristicPass()])
+                ]
+            ),
+        ]
+    )
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2 * 2 * 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 6
+
+    # Users can mix sequential application and joint application.
+    mock_pass = MockConstFoldingTuningPass(
+        eval_passes=[
+            MockConstFoldingTuningPass(eval_passes=[]),
+            MockConstFoldingTuningPass(eval_passes=[]),
+            MockConstFoldingTuningPass(eval_passes=[]),
+        ]
+    )
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = mock_pass(mod)
+        assert PassContext.current().num_evals == 2 * (2 + 2 + 2)
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 4
+
+
+def test_sequential():
+    mod = setup_test()
+    assert isinstance(mod, tvm.IRModule)
+    # Test setup.
+    c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+    mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
+    HeuristicPass = relax.transform.FoldConstant
+
+    # A Sequential with a single tuning pass should behave the same as that single pass.
+    seq = transform.Sequential([MockConstFoldingTuningPass(eval_passes=[])])
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = seq(mod)
+        assert PassContext.current().num_evals == 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 1
+
+    # A sequential pass increases the search space (num_evals) in an additive manner.
+    seq = transform.Sequential(
+        [
+            MockConstFoldingTuningPass(eval_passes=[]),
+            MockConstFoldingTuningPass(eval_passes=[]),
+            MockConstFoldingTuningPass(eval_passes=[]),
+        ]
+    )
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = seq(mod)
+        assert PassContext.current().num_evals == 2 + 2 + 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 3
+
+    # A heuristic pass does not increase the search space; it only increases the trace length.
+    seq = transform.Sequential(
+        [
+            MockConstFoldingTuningPass(eval_passes=[]),
+            HeuristicPass(),
+            MockConstFoldingTuningPass(eval_passes=[]),
+            MockConstFoldingTuningPass(eval_passes=[]),
+            HeuristicPass(),
+        ]
+    )
+
+    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+        _ = seq(mod)
+        assert PassContext.current().num_evals == 2 + 2 + 2
+        assert PassContext.current().get_trace_stack_size() == 1
+        assert PassContext.current().get_current_trace().size == 5
+
+    # Users can mix sequential application and joint application.
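+    # Expected cost below: the jointly-optimized tuning chain contributes 2 * 2 * 2 evals, the
+    # standalone tuning pass adds 2 more, and the heuristic passes only extend the trace (size 7).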
+ seq = transform.Sequential( + [ + HeuristicPass(), + MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass( + eval_passes=[ + HeuristicPass(), + ] + ) + ] + ), + ] + ), + MockConstFoldingTuningPass(eval_passes=[]), + HeuristicPass(), + ] + ) + + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = seq(mod) + assert PassContext.current().num_evals == (2 * 2 * 2) + 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 7 + + +def test_passes_with_mixed_granularities(): + @tvm.script.ir_module + class MockModule: + @R.function + def f1(x: R.Tensor(("m", "n"), "float32")): + with R.dataflow(): + lv0 = R.multiply(x, x) + gv0 = R.add(x, x) + R.output(gv0) + return gv0 + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): + with R.dataflow(): + lv0 = R.multiply(x, y) + gv0 = R.add(lv0, y) + R.output(gv0) + gv1 = R.multiply(x, y) + gv2 = R.add(gv1, y) + return (gv0, gv1, gv2) + + mod = MockModule + assert isinstance(mod, tvm.IRModule) + + # Helper function for tuning + def pass_func( + mod: IRModule, ctx: PassContext, eval_passes: List[transform.Pass] = None + ) -> IRModule: + trace = ctx.pop_trace() + + # Create mock choices for testing + choices = [Choice(), Choice(), Choice()] + # Tuning pass manages a set of transformation functions registered via knob. + knob = Knob("MockTuningKnob", choices) + + candidates = default_generate_candidate([knob], trace, eval_passes) + mock_evaluate(candidates, "llvm", ctx) + best_trace = select_best_candidate(candidates) + + ctx.push_trace(best_trace) + return best_trace.out_mod + + @ir.transform.module_pass(opt_level=0, traceable=True) + def MockModulePass(mod: IRModule, ctx: PassContext) -> IRModule: + # Input granularity == Candidate granularity. + return pass_func(mod, ctx) + + @relax.transform.function_pass(opt_level=0, traceable=True) + def MockFunctionPass(func: Expr, mod: IRModule, ctx: PassContext) -> Function: + # Input granularity > Candidate granularity. + # Start trace with smaller granularity: IRModule->Function. + ctx.push_trace(Trace(IRModule.from_expr(func))) + # Do something. + pass_func(mod, ctx) + # Pop tuned trace and recover the previous trace. + ctx.pop_trace() + return func + + @relax.transform.dataflowblock_pass(opt_level=0, traceable=True) + def MockDataflowBlockPass( + block: DataflowBlock, mod: IRModule, ctx: PassContext + ) -> DataflowBlock: + # TODO(sunggg): figure out how to create IRModule from DataflowBlock + # Provide random binding for now + x = relax.Var("x", R.Tensor([tvm.tir.Var("n", "int64")], "float32")) + seq_expr = relax.SeqExpr([block], x) + func = relax.Function([x], seq_expr, R.Tensor("float32", ndim=-1)) + ctx.push_trace(Trace(IRModule.from_expr(func))) + # Do something + pass_func(mod, ctx) + ctx.pop_trace() + return block + + seq = transform.Sequential( + [ + MockModulePass, + MockFunctionPass, + MockDataflowBlockPass, + ] + ) + + with transform.PassContext(trace=Trace(mod), make_traceable=[]): + _ = seq(mod) + # Trace length and num eval can be different depending on how each function/dataflow block is treated. 
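+        # Only the trace-stack depth is granularity-independent, so that is the invariant checked here.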
+ assert PassContext.current().get_trace_stack_size() == 1 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py new file mode 100644 index 000000000000..e103e9cddded --- /dev/null +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax, tir, topi +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder.base import IRBuilder + + +def test_function_simple(): + """ + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + out = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return out + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + R.func_attr({"Primitive": 1}) + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) + y = R.emit( + R.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + out = R.emit( + R.call_dps_packed( + "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + IRBuilder.name("out", out) + R.func_ret_value(out) + func = ir_builder.get() + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + y = bb.emit( + relax.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + out = bb.emit( + relax.call_dps_packed( + "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + bb.emit_func_output(out) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + # check names + assert func.params[0].name_hint == "x" + assert func.body.body.name_hint == "out" + + +def test_emits(): + """Tests for R.emit, R.emit_match_cast, R.emit_var_binding + + @R.function + def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2): + m = T.int64() + n = T.int64() + gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32")) + gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,), dtype="float32")) + v: R.Tensor((n,), dtype="float32") = gv1 + return R.shape([m, n * 2]) + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32")) + y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32")) + m = tir.Var("m", dtype="int64") + n = 
tir.Var("n", dtype="int64") + _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) + v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) + vb = relax.VarBinding(v, y1) + v = R.emit_var_binding(vb) + R.emit(v) + + IRBuilder.name("v", v) + R.func_ret_value(relax.ShapeExpr([m, n * 2])) + func = ir_builder.get() + + # create with BlockBuilder + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) + y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) + v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32")) + bb.emit_normalized(relax.VarBinding(v, y1)) + bb.emit(v) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + + +def test_dataflow_block(): + """ + @R.function + def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): + # block 0 + with R.dataflow(): + lv0 = R.call_dps_packed("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) + gv: Tensor((128, 128), "float32") = lv0 + R.output(gv) + return gv + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + with R.dataflow() as df: + lv0 = R.emit( + R.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + IRBuilder.name("lv0", lv0) + gv = R.emit(lv0) + IRBuilder.name("gv", gv) + R.output(gv) + (gv,) = df.output_vars + R.func_ret_value(gv) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + tvm.ir.assert_structural_equal(func, bb.get()["foo"]) + + +def test_regression_py_print(): + # Test that the py_print directs to python builtin print + from tvm.script.ir_builder.relax.ir import py_print # pylint: disable=import-outside-toplevel + + assert py_print == print + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py new file mode 100644 index 000000000000..9b8865b9436f --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser.py @@ -0,0 +1,1366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +from typing import Optional, Union + +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax, tir, topi +from tvm.ir import DummyGlobalInfo +from tvm.script.parser import ir as I +from tvm.script.parser import relax as R +from tvm.script.parser import tir as T + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]] = None, +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if isinstance(parsed, IRModule) and isinstance(roundtrip_mod, IRModule): + assert relax.analysis.well_formed(parsed) + assert relax.analysis.well_formed(roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_simple_func(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + R.func_attr({"Primitive": 1}) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), dtype="float32")) + return gv1 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + out = bb.emit( + relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), dtype="float32")) + ) + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +def test_error_report(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + # error: a = b = c is not allowed. 
+ gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + +def test_mismatch_cast_dims_and_ndim(): + with pytest.raises(Exception): + + @R.function + def f( + x: R.Tensor((2, 3), "float32", ndim=3) + ): # error: ndim and the shape dims are mismatch + return x + + +def test_unexpected_num_kw_args(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim=1, foo=2)): # error: unexpected kw args foo + return x + + +def test_unexpected_ndim(): + with pytest.raises(Exception): + + @R.function + # error: dim is expected to be non-negative int or -1 for unknown + def f(x: R.Tensor(dtype="float32", ndim=-2)): + return x + + +def test_unexpected_ndim_type(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim="1")): # error: dim is expected to be int + return x + + +def test_unexpected_tir_cast_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m",), "float32")): + m = T.int64() + # tir.cast expects 2 arguments, but got 3 + return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) + + +def test_unexpected_tir_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class TestWellCallTIR: + @T.prim_func + def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "tir_addone"})) + for i, j in T.grid(16, 16): + with T.block("tir_addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor(("m", "m"), "float32")): + m = T.int64() + # tir.max expects 2 arguments, but got 1 + gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32")) + return gv + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m", "n"), "float32")): + m = T.int64() + # call_tir expected a tir prim_func + return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32")) + + +def test_func_type_annotation_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x, y): # error: the parameter type annotation is missing + z = R.add(x, y) + y = z + return y + + +def test_if_mismatch_var_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + z = R.add(w, w) # error: The binding var is expected to `y` + return z + + +def test_unassigned_call_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor): + R.add(x, x) + return x + + +def test_simple_module(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + cls = TestModule + gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + 
bb.emit_func_output(out) + + _check(TestModule, bb.get()) + + +def test_emit_te_primfunc_attrs(): + @I.ir_module + class TestModule: + @T.prim_func + def plus_one( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"some_attr": "foo", "another_attr": True, "tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + cls = TestModule + gv0 = R.call_tir(cls.plus_one, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te( + lambda x: x + 1, + x, + primfunc_name_hint="plus_one", + primfunc_attrs={"some_attr": "foo", "another_attr": True}, + ) + bb.emit_func_output(out) + _check(TestModule, bb.get()) + + +def test_emit_te(): + @I.ir_module + class EmitTE: + @R.function + def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32"): + lv1 = R.emit_te(topi.add, x, x) + out = R.emit_te(topi.multiply, lv1, lv1) + return out + + bb = relax.BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) + with bb.function("main", [x]): + lv1 = bb.emit_te(topi.add, x, x) + out = bb.emit_te(topi.multiply, lv1, lv1) + bb.emit_func_output(out) + + _check(EmitTE, bb.get()) + + +def test_module_with_attr_and_global_info(): + @I.ir_module + class TestModule: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "dummy": [ + I.dummy_global_info(), # dummy[0] + I.dummy_global_info(), # dummy[1] + ] + } + ) + + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + cls = TestModule + gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + mod = bb.get() + mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()]) + mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10)) + _check(TestModule, mod) + + +def test_relax_tensor_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): + y = R.add(x, x) + z = R.multiply(x, y) + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32") + shape = R.shape_of(alloc) + return shape + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) + shape = 
bb.emit(relax.op.shape_of(alloc)) + bb.emit_func_output(shape) + + _check(foo, bb.get()["foo"]) + + +def test_relax_shape_to_tensor(): + @R.function + def foo(x: R.Shape((4, 4))): + tensor = R.shape_to_tensor(x) + return tensor + + x = relax.Var("x", R.Shape((4, 4))) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + tensor = bb.emit(relax.op.shape_to_tensor(x)) + bb.emit_func_output(tensor) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + @R.function + def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.int64() + n = T.int64() + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.int64() + n = T.int32() # The shape dtype should be int64 + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + def _expected(name: str): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([m, n], "float32")) + bb = relax.BlockBuilder() + with bb.function(name, (x,)): + out = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) + ) + bb.emit_func_output(out) + return bb.get()[name] + + _check(foo, _expected("foo")) + _check(bar, _expected("bar")) + + +def test_shadowing(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + y = R.add(x, x) + z = R.multiply(x, y) + y = R.add(x, y) + y = z + y = R.multiply(y, x) + z = y + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + y = bb.emit(relax.op.add(x, y)) + y = bb.emit(z) + y = bb.emit(relax.op.multiply(y, x)) + z = bb.emit(y) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_match_cast(): + @R.function + def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): + m = T.int64() + n = T.int64() + x0 = R.match_cast(x, R.Tensor([m], "float32")) + with R.dataflow(): + y0 = R.match_cast(y, R.Tensor([n], "float32")) + gv = y0 + R.output(gv) + return (x0, R.shape([m, n * 2])) + + x = relax.Var("x", R.Tensor("float32")) + y = relax.Var("y", R.Tensor("float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + y2 = relax.Var("y", R.Tensor([n], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + x0 = bb.match_cast(x, R.Tensor([m], "float32")) + with bb.dataflow(): + y0 = bb.match_cast(y, R.Tensor([n], "float32")) + bb.emit_output(y0) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])])) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) + gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) + return (gv0, gv1) + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) + gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, 
R.Tensor((4, 4), dtype="float32"))) + bb.emit_func_output(relax.Tuple((gv0, gv1))) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return_2(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.int64(), T.int64() + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + return (x0, R.shape([n + 1, m, 1])) + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])])) + + _check(foo, bb.get()["foo"]) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +def test_tuple_binding(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.int64(), T.int64() + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + t0 = (x, x0) + t1 = (x, R.shape([n, m]), t0) + return t1 + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + t0 = bb.emit(relax.Tuple([x, x0])) + t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0])) + bb.emit_func_output(t1) + + _check(foo, bb.get()["foo"]) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +def test_tuple_get_item(): + @R.function + def foo(x: R.Tensor, y: R.Tensor): + t1 = R.tuple(x, y) + t2 = (x, y) + a = t1[0] + b = R.TupleGetItem(t2, 1) + c = R.add(a, b) + return c + + x = relax.Var("x", R.Tensor()) + y = relax.Var("y", R.Tensor()) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + t1 = bb.emit(relax.Tuple([x, y])) + t2 = bb.emit(relax.Tuple([x, y])) + a = bb.emit(relax.TupleGetItem(t1, 0)) + b = bb.emit(relax.TupleGetItem(t2, 1)) + c = bb.emit(relax.op.add(a, b)) + bb.emit_func_output(c) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv = lv1 + R.output(gv) + return gv + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + lv1 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block_advanced(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + m = T.int64() + n = T.int64() + lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) + gv3 = R.match_cast(lv0, R.Tensor((m, n), 
"float32")) + gv4 = gv3 + gv5 = gv2 + R.output(gv5, gv4) + gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) + return gv7 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + with bb.function("foo", (x,)): + gv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + gv1 = bb.emit( + relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + ) + with bb.dataflow(): + lv0 = bb.emit( + relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + ) + lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) + gv21 = bb.emit( + relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + ) + gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) + gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv32 = bb.emit_output(gv31) + gv22 = bb.emit_output(gv21) + gv4 = bb.emit( + relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), dtype="float32")) + ) + gv5 = bb.emit( + relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), dtype="float32")) + ) + bb.emit_func_output(gv5) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_binding_after_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32")) + return gv + + +def test_dataflow_output_global_var(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + R.output(gv0, gv1) + return gv1 + + +def test_dataflow_multiple_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + R.output(gv) + return gv + + +def test_dataflow_output_outside_dataflow_block(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + return gv + + +def test_dataflow_scope_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(ndim=2)): + with R.dataflow(): + y = R.add(x, x) + z = R.multiply(y, x) + w = R.add(z, x) + R.output(y, w) + t = R.multiply(y, z) # z is not in the outer scope + return t + + +def test_return_without_binding(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_multiple_return(): + with 
pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + return x + + +def test_function_without_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + + +def test_tensor_type_without_args(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32")) + return v + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32"))) + bb.emit_func_output(v) + + _check(foo, bb.get()["foo"]) + + +def test_direct_return(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_call_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + relax.Call( + relax.ExternFunc("vm.builtin.copy"), + (x,), + None, + sinfo_args=[R.Tensor((32, 32), "float32")], + ) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_annotation(): + @R.function + def foo( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m",), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.int64() + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) + return o + + def _check_struct_info(binding, expected_sinfo): + tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo) + + # Cannot use block builder here because we need to check the annotated type, + # which may be inconsistent with deduced type. 
+ assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + m = relax.get_shape_of(foo.params[0])[1] + bindings = foo.body.blocks[0].bindings + sh = bindings[4].var + + _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) + _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) + _check_struct_info(bindings[6], relax.ObjectStructInfo()) + + +def test_annotate_override(): + @R.function + def foo(x: R.Tensor): + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = R.add(x, y) + return z + + assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + y_bind, z_bind = foo.body.blocks[0].bindings + assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo) + assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo) + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def test(x: R.Tensor): + # Error: x is of Tensor StructInfo, which can not annotate to R.Shape. + z: R.Shape = x + return z + + @R.function + def bar(x: R.Tensor): + # x is of Tensor StructInfo, the annotation of `z` is ignored. + z: R.Object = x + return z + + assert isinstance(bar.ret_struct_info, relax.TensorStructInfo) + (z_bind,) = bar.body.blocks[0].bindings + assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) + + +def test_call_dps_packed_empty_shape(): + @R.function + def foo(x: R.Tensor((), "float32")): + z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32")) + return z + + (z_bind,) = foo.body.blocks[0].bindings + shape_expr = z_bind.value.sinfo_args[0].shape + + assert isinstance(shape_expr, relax.ShapeExpr) + assert len(shape_expr.values) == 0 + + +def test_call_tir_empty_tuple_arg(): + bb = relax.BlockBuilder() + dummy_param = relax.Var("dummy_param", R.Tensor(())) + with bb.function("foo", [dummy_param]): + output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) + bb.emit_func_output(output) + + _check(bb.get()) + + +def test_call_tir_with_tir_var(): + @I.ir_module + class Module: + @R.function + def main( + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + ) -> R.Tensor(("n * 2",), "float32"): + n = T.int64() + cls = Module + y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + return y + + @T.prim_func + def copy(var_x: T.handle, var_y: T.handle, n: T.int64): + X = T.match_buffer(var_x, (n * 2,), dtype="float32") + Y = T.match_buffer(var_y, (n * 2,), dtype="float32") + for i in T.grid(n * 2): + with T.block("block"): + vi = T.axis.remap("S", [i]) + Y[vi] = X[vi] + + _check(Module) + + +def test_local_function(): + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + main_bindings = main.body.blocks[0].bindings + assert len(main_bindings) == 3 + outer_func = main_bindings[0].value 
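+    # The first binding in `main` holds the local `outer_func` closure definition.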
+ assert isinstance(outer_func, relax.Function) + + outer_func_bindings = outer_func.body.blocks[0].bindings + assert len(outer_func_bindings) == 1 + inner_func = outer_func_bindings[0].value + assert isinstance(inner_func, relax.Function) + + +def test_inline_prim_func(): + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) + return z + + +def test_cross_function_call(): + @I.ir_module + class Mod0: + @R.function + def foo(x: R.Tensor((10, 5), "float32")): + s = R.add(x, x) + return s + + @R.function + def main(x: R.Tensor((10, 5), "float32")): + cls = Mod0 + inner = cls.foo + gv1 = inner(x) + gv2 = Mod0.foo(x) + return (inner, gv1, gv2) + + @I.ir_module + class Mod1: + @R.function + def main(x: R.Tensor((10, 5), "float32")): + cls = Mod1 + inner = cls.foo + gv1 = inner(x) + gv2 = Mod1.foo(x) + return (inner, gv1, gv2) + + @R.function + def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): + s = R.add(x, x) + return s + + +def test_if_branch(): + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return y + + cond, x = foo.params + y_bind = foo.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + def check_call(call, op, args): + assert isinstance(call, relax.Call) + if isinstance(op, str): + assert call.op.name == op + else: + assert call.op == op + tvm.ir.assert_structural_equal(call.args, args) + + w_bind = ite.true_branch.blocks[0].bindings[0] + # the seq exprts in the branches are normalized to bind any call + # in the seq expr "body" to a var + y_bind = ite.true_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.add", [x, x]) + check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) + + w_bind = ite.false_branch.blocks[0].bindings[0] + y_bind = ite.false_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.multiply", [x, x]) + check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) + + +def test_if_inside_dataflow(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + with R.dataflow(): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + R.output(y) + return y + + +def test_var_if_scoping_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w # error: The w is not defined in the 
outer scope + + +def test_if_branch_var_scope(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w + + +def test_erase_to_well_defined(): + @R.function + def foo(x: R.Tensor): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) + _check(foo) + + +def test_empty_tuple(): + @R.function + def foo(x: R.Tuple()): + y: R.Tuple() = R.tuple() + return y + + x = relax.Var("x", relax.TupleStructInfo([])) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.Tuple([])) + bb.emit_func_output(y) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape_computing(): + # Tensor Case 1 + @R.function + def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): + z = R.add(x, y) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + z = bb.emit(relax.op.add(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + # Tensor Case 2 + @R.function + def bar( + x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") + ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): + m = T.int64() + z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m], "float32")) + y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) + bb = relax.BlockBuilder() + with bb.function("bar", (x, y)): + z = bb.emit( + relax.call_dps_packed( + "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32") + ) + ) + bb.emit_func_output(z) + + _check(bar, bb.get()["bar"]) + + # Shape Case + @R.function + def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): + m = T.int64() + z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.ShapeStructInfo([m])) + y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + bb = relax.BlockBuilder() + with bb.function("baz", (x, y)): + z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + bb.emit_func_output(z) + + _check(baz, bb.get()["baz"]) + + # Error Case + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined + z = R.add(x, x) + return z + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +def test_arith_operators(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): + a0 = -x + a1 = x + y + a2 = x - y + a3 = x * y + a4 = x / y + a5 = x // y + a6 = x**y + + c0 = x > y + c1 = x < y + c2 = x >= y + c3 = x <= y + + tuple_expr = ((x, x), y) + t0 = tuple_expr[0] + t1 = tuple_expr[1] + t2 = tuple_expr[0][0] # <= Will normalize to two bindings + return (a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2) + + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, n], 
"float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + a0 = bb.emit(relax.op.negative(x)) + a1 = bb.emit(relax.op.add(x, y)) + a2 = bb.emit(relax.op.subtract(x, y)) + a3 = bb.emit(relax.op.multiply(x, y)) + a4 = bb.emit(relax.op.divide(x, y)) + a5 = bb.emit(relax.op.floor_divide(x, y)) + a6 = bb.emit(relax.op.power(x, y)) + + c0 = bb.emit(relax.op.greater(x, y)) + c1 = bb.emit(relax.op.less(x, y)) + c2 = bb.emit(relax.op.greater_equal(x, y)) + c3 = bb.emit(relax.op.less_equal(x, y)) + + tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y))) + t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1)) + tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t2 = bb.emit(relax.TupleGetItem(tmp, 0)) + bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2))) + + _check(foo, bb.get()["foo"]) + + +def test_memory_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + m = T.int64() + n = T.int64() + storage = R.memory.alloc_storage( + R.shape([4 * m * n]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc = R.memory.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32") + tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) + gv = tensor + return alloc, gv + + _check(foo) + + +def test_vm_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + m = T.int64() + n = T.int64() + storage = R.vm.alloc_storage(R.shape([4 * m * n]), runtime_device_index=0, dtype="float32") + alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32") + tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) + tir_dym = R.vm.call_tir_dyn("te_func", (x, tensor, R.ShapeExpr((m, n)))) + return alloc, tir_dym + + _check(foo) + + +def test_builtin_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + tensor = R.builtin.stop_lift_params(x) + gv = tensor + return gv + + _check(foo) + + +def test_prim_value(): + @R.function + def foo(): + gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_string_imm(): + @R.function + def foo(): + gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_datatype_imm(): + @R.function + def foo(): + gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_function_void_return_type(): + @tvm.script.ir_module + class Foo: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")): + res = Foo.mul(x) + return res + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")): + res = R.multiply(x, x) + return res + + _check(Foo) + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. + assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo) + assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo) + + @tvm.script.ir_module + class Bar: + @R.function + def main(x1: R.Tensor((3, 3), dtype="float32")): + res1 = Bar.mul(x1) + return res1 + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")) -> None: + res = R.multiply(x, x) + return res + + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. 
+ _check(Bar) + tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([])) + tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([])) + + +def test_class_normalize(): + @tvm.script.ir_module + class InputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + return R.multiply(R.add(x, x), R.add(x, x)) + + # The parser automatically normalizes the input AST to the following ANF form + @tvm.script.ir_module + class OutputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + gv = R.add(x, x) + gv1 = R.add(x, x) + return R.multiply(gv, gv1) + + _check(InputModule, OutputModule) + + +def test_context_aware_parsing(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + X: T.Buffer(T.int64(8), "float32"), + Y: T.Buffer((), "float32"), + Z: T.Buffer(T.int64(8), "float32"), + ): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Module + alloc = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0) + _: R.Tuple() = cls.add(x, R.const(1, "float32"), alloc) + return alloc + + _check(Module) + + # Break the env settings, but context-aware parsing can still handle it + def _break_env(self, *args): + raise RuntimeError("Fail to pass context-aware parsing") + + tvm.ir.GlobalVar.__call__ = _break_env + + _check(Module) + + +def test_unit_tuple_on_rhs_of_assign(): + @I.ir_module + class Module: + @R.function + def main(input: R.Tensor((5, 5))) -> R.Tuple(R.Tensor((5, 5))): + gv = (input,) + return gv + + _check(Module) + + +def test_empty_tuple_on_rhs_of_assign(): + @I.ir_module + class Module: + @R.function + def main(input: R.Tensor((5, 5))) -> R.Tuple(): + gv = () + return gv + + _check(Module) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py new file mode 100644 index 000000000000..d43e9a626b66 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
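+# Parser tests for elementwise arithmetic, check, and comparison operators: each @R.function is
+# expected to be structurally equal to the same program built with relax.BlockBuilder.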
+ +from typing import Optional, Union, Callable + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +(unary_arith_op,) = tvm.testing.parameters( + (relax.op.abs,), + (relax.op.acos,), + (relax.op.acosh,), + (relax.op.asin,), + (relax.op.asinh,), + (relax.op.atan,), + (relax.op.atanh,), + (relax.op.ceil,), + (relax.op.cos,), + (relax.op.cosh,), + (relax.op.exp,), + (relax.op.floor,), + (relax.op.log,), + (relax.op.negative,), + (relax.op.round,), + (relax.op.sigmoid,), + (relax.op.sign,), + (relax.op.sin,), + (relax.op.sinh,), + (relax.op.square,), + (relax.op.sqrt,), + (relax.op.tan,), + (relax.op.tanh,), +) + + +def test_unary_arith(unary_arith_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = unary_arith_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_arith_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(unary_check_op,) = tvm.testing.parameters( + (relax.op.isfinite,), + (relax.op.isinf,), + (relax.op.isnan,), +) + + +def test_unary_check(unary_check_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = unary_check_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_check_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.power,), + (relax.op.subtract,), + (relax.op.maximum,), + (relax.op.minimum,), +) + + +def test_binary_arith(binary_arith_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = binary_arith_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_arith_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp(binary_cmp_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = binary_cmp_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_cmp_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_relax_ewise_fma(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32"), + y: R.Tensor((2, 3, 4), dtype="float32"), + z: R.Tensor((2, 3, 
4), dtype="float32"), + ) -> R.Tensor((2, 3, 4), dtype="float32"): + gv: R.Tensor((2, 3, 4), dtype="float32") = R.ewise_fma(x, y, z) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z = relax.Var("z", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y, z]): + gv = bb.emit(relax.op.ewise_fma(x, y, z)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_create.py b/tests/python/relax/test_tvmscript_parser_op_create.py new file mode 100644 index 000000000000..6cbc0ebf906a --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_create.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_full(): + @R.function + def foo(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full((2, 3), v, dtype="float32") + return gv + + bb = relax.BlockBuilder() + v = relax.Var("v", R.Tensor((), "int32")) + with bb.function("foo", [v]): + gv = bb.emit(relax.op.full((2, 3), v, "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_full_like(): + @R.function + def foo( + x: R.Tensor((2, 3), "float16"), v: R.Tensor((), "float32") + ) -> R.Tensor((2, 3), "float16"): + gv: R.Tensor((2, 3), "float16") = R.full_like(x, v) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float16")) + v = relax.Var("y", R.Tensor((), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, v]): + gv = bb.emit(relax.op.full_like(x, v)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.ones((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones_like(x) + return gv + + x = 
relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.ones_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.zeros((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros_like(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.zeros_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_tril(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.tril(x, k=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.tril(x, k=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_triu(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.triu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.triu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py new file mode 100644 index 000000000000..ec71e868d45b --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_astype(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16"): + gv: R.Tensor((2, 3, 4), "float16") = R.astype(x, "float16") + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.astype(x, "float16")) + bb.emit_func_output(gv) + + _check(expected, bb.get()["main"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py b/tests/python/relax/test_tvmscript_parser_op_image.py new file mode 100644 index 000000000000..a90da37812ef --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_image.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_resize2d(): + @R.function + def foo(x: R.Tensor((2, 14, 14, 3), "float32")) -> R.Tensor((2, 28, 28, 3), "float32"): + gv: R.Tensor((2, 28, 28, 3), "float32") = R.image.resize2d(x, size=(28, 28), layout="NHWC") + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 14, 14, 3), "float32")) + with bb.function("foo", [x]): + gv = bb.emit(relax.op.image.resize2d(x, (28, 28), layout="NHWC")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_index.py b/tests/python/relax/test_tvmscript_parser_op_index.py new file mode 100644 index 000000000000..b271d1a7f3bc --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_index.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_take(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3,), "int64") + ) -> R.Tensor((2, 3, 3), "float32"): + gv: R.Tensor((2, 3, 3), "float32") = R.take(x, indices, axis=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + indices = relax.Var("indices", R.Tensor((3,), "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, indices]): + gv = bb.emit(relax.op.take(x, indices, axis=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_strided_slice(): + @R.function + def foo(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice( + x, + axes=[0, 1, -1], + begin=[1, 0, 8], + end=[8, 9, 0], + strides=[2, 1, -3], + ) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.strided_slice( + x, axes=[0, 1, -1], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py new file mode 100644 index 000000000000..1ed7fa9b917c --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_matmul(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((6, 2, 3, 5, 7), "float32") + ) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((6, 2, 3, 5, 7), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(relax.op.matmul(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_linear(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + w: R.Tensor((3, 5), "float32"), + bias: R.Tensor((3,), "float32"), + ): + gv = R.linear(x, w, bias) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + w = relax.Var("y", R.Tensor((3, 5), "float32")) + bias = relax.Var("bias", R.Tensor((3,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w, bias]): + w_T = bb.emit(relax.op.permute_dims(w, axes=None)) + matmul = bb.emit(relax.op.matmul(x, w_T)) + out = matmul + bias + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py new file mode 100644 index 000000000000..a797885e9669 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -0,0 +1,407 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_broadcast_to(): + @R.function + def foo(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + with bb.function("foo", [x]): + gv = bb.emit(relax.op.broadcast_to(x, (4, 2, 5, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_concat(): + @R.function + def foo( + x1: R.Tensor((1, 2, 3), "float32"), + x2: R.Tensor((1, 3, 3), "float32"), + x3: R.Tensor((1, 4, 3), "float32"), + ) -> R.Tensor((1, 9, 3), "float32"): + gv: R.Tensor((1, 9, 3), "float32") = R.concat((x1, x2, x3), axis=1) + return gv + + x1 = relax.Var("x1", R.Tensor((1, 2, 3), "float32")) + x2 = relax.Var("x2", R.Tensor((1, 3, 3), "float32")) + x3 = relax.Var("x3", R.Tensor((1, 4, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x1, x2, x3]): + gv = bb.emit(relax.op.concat((x1, x2, x3), axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_concat_without_specified_axis(): + @R.function + def foo( + x1: R.Tensor((2,), "float32"), x2: R.Tensor((3,), "float32"), x3: R.Tensor((4,), "float32") + ) -> R.Tensor((9,), "float32"): + gv: R.Tensor((9,), "float32") = R.concat((x1, x2, x3), axis=None) + return gv + + x1 = relax.Var("x1", R.Tensor((2,), "float32")) + x2 = relax.Var("x2", R.Tensor((3,), "float32")) + x3 = relax.Var("x3", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x1, x2, x3]): + gv = bb.emit(relax.op.concat((x1, x2, x3), axis=None)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_expand_dims(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims(x, axis=[-1, 1, -6, 3, 5]) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.expand_dims(x, axis=[-1, 1, -6, 3, 5])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_flatten(): + @R.function + def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((60,), "float32"): + gv: R.Tensor((60,), "float32") = R.flatten(x) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.flatten(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layout_transform(): + transformation = lambda n, c, h, w: (n, h, w, c) + + @R.function + def foo(x: R.Tensor((2, 3, 4, 5), "float32")): + gv: R.Tensor((2, 4, 5, 3), "float32") = R.layout_transform(x, index_map=transformation) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.layout_transform(x, index_map=transformation)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def 
test_layout_transform_with_padding(): + transformation = lambda n, c, h, w: (n, c // 3, h, w, c % 3) + + @R.function + def foo(x: R.Tensor((10, 20, 2, 2), "float32")): + gv: R.Tensor((10, 7, 2, 2, 3), "float32") = R.layout_transform( + x, index_map=transformation, pad_value=2 + ) + return gv + + x = relax.Var("x", R.Tensor((10, 20, 2, 2), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.layout_transform(x, index_map=transformation, pad_value=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_permute_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv: R.Tensor((2, 4, 3, 1), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.permute_dims(x, axes=[1, -1, 2, -4])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_permute_dims_none_arg(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((4, 3, 2, 1), "float32"): + gv: R.Tensor((4, 3, 2, 1), "float32") = R.permute_dims(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.permute_dims(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_reshape(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv: R.Tensor((8, 3), "float32") = R.reshape(x, (8, 3)) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.reshape(x, shape=(8, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_reshape_infer_dim(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 1, 3), "float32"): + gv: R.Tensor((8, 1, 3), "float32") = R.reshape(x, (8, -1, 3)) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.reshape(x, shape=(8, -1, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_split_by_indices(): + @R.function + def foo( + x: R.Tensor((2, 10, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 1, 4), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 1, 4), dtype="float32"), + ) = R.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_split_by_n_section(): + @R.function + def foo( + x: R.Tensor((2, 10, 4), dtype="float32") + ) -> R.Tuple( + 
R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + ) = R.split(x, indices_or_sections=5, axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.split(x, indices_or_sections=5, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_squeeze(): + @R.function + def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x) + return gv + + x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.squeeze(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_squeeze_with_indices(): + @R.function + def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x, axis=[3, -5]) + return gv + + x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.squeeze(x, axis=[3, -5])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_collapse_sum_like(): + @R.function + def foo( + x: R.Tensor((3, 4, 5), "float32"), y: R.Tensor((4, 5), "float32") + ) -> R.Tensor((4, 5), "float32"): + gv: R.Tensor((4, 5), "float32") = R.collapse_sum_like(x, y) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(relax.op.collapse_sum_like(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_collapse_sum_to(): + @R.function + def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((4, 5), "float32"): + gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(x, (4, 5)) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.collapse_sum_to(x, (4, 5))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_repeat(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")): + gv = R.repeat(x, 3, 1) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.repeat(x, 3, 1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_repeat_no_axis(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")): + gv = R.repeat(x, 3) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.repeat(x, 3)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_tile(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")): + gv = R.tile(x, (2, 3)) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.tile(x, (2, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def 
test_cumsum(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")): + gv = R.cumsum(x, axis=1, dtype="int32") + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.cumsum(x, axis=1, dtype="int32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py new file mode 100644 index 000000000000..a822fae71922 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_conv1d(): + @R.function + def foo( + x: R.Tensor((2, 3, 228), "float16"), w: R.Tensor((16, 3, 5), "float16") + ) -> R.Tensor((2, 16, 224), "float16"): + gv: R.Tensor((2, 16, 224), "float16") = R.nn.conv1d(x, w, out_dtype="float16") + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228], "float16")) + w = relax.Var("w", R.Tensor([16, 3, 5], "float16")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w]): + gv = bb.emit(relax.op.nn.conv1d(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_conv2d(): + @R.function + def foo( + x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((16, 3, 5, 5), "float16") + ) -> R.Tensor((2, 16, 224, 224), "float16"): + gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, out_dtype="float16") + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16")) + w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float16")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w]): + gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_conv2d_transpose(): + @R.function + def foo( + x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((3, 16, 5, 5), "float16") + ) -> R.Tensor((2, 16, 232, 232), "float16"): + gv: R.Tensor((2, 16, 232, 232), "float16") = R.nn.conv2d_transpose( + x, w, out_dtype="float16" + ) + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16")) + w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float16")) + bb = relax.BlockBuilder() + with bb.function("foo", 
[x, w]): + gv = bb.emit(relax.op.nn.conv2d_transpose(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max_pool2d(): + @R.function + def foo( + x: R.Tensor((1, 1, 32, 32), dtype="float32") + ) -> R.Tensor((1, 1, 30, 30), dtype="float32"): + gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.max_pool2d(x, pool_size=(3,)) + return gv + + x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.max_pool2d(x, pool_size=(3,))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_avg_pool2d(): + @R.function + def foo( + x: R.Tensor((1, 1, 32, 32), dtype="float32") + ) -> R.Tensor((1, 1, 30, 30), dtype="float32"): + gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.avg_pool2d(x, pool_size=(3,)) + return gv + + x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.avg_pool2d(x, pool_size=(3,))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_adaptive_avg_pool2d(): + @R.function + def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7), "float32"): + gv: R.Tensor((2, 64, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=(7, 7)) + return gv + + x = relax.Var("x", R.Tensor((2, 64, 8, 9), dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(7, 7))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_gelu(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.gelu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_log_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.log_softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.log_softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_batch_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 3, 3), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), 
"float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((4,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((4,), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta, moving_mean, moving_var]): + gv = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layer_norm(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + gamma: R.Tensor((4, 5), "float32"), + beta: R.Tensor((4, 5), "float32"), + ) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit(relax.op.nn.layer_norm(x, gamma, beta, axes=[-2, -1])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_group_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 4, 5), "float32"), + gamma: R.Tensor((4,), "float32"), + beta: R.Tensor((4,), "float32"), + ) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm( + x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3] + ) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dropout(): + @R.function + def foo( + x: R.Tensor((2, 3), "float32") + ) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")): + gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout( + x, rate=0.5 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.dropout(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_cross_entropy_with_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_with_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_search.py b/tests/python/relax/test_tvmscript_parser_op_search.py new file mode 100644 index 000000000000..86d0cfc9bcf9 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_search.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_where(): + @R.function + def foo( + condition: R.Tensor((2, 1), "bool"), + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((1, 3), "float32"), + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.where(condition, x, y) + return gv + + bb = relax.BlockBuilder() + condition = relax.Var("condition", R.Tensor((2, 1), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((1, 3), "float32")) + with bb.function("foo", [condition, x, y]): + gv = bb.emit(relax.op.where(condition, x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_argmax(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "int64"): + gv: R.Tensor((1, 3, 4), "int64") = R.argmax(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.argmax(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_argmax_without_specified_axis(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((), "int64"): + gv: R.Tensor((), "int64") = R.argmax(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.argmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_argmax_keep_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 3, 4), "int64"): + gv: R.Tensor((1, 1, 3, 4), "int64") = R.argmax(x, axis=1, keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.argmax(x, axis=1, keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_set.py b/tests/python/relax/test_tvmscript_parser_op_set.py new file mode 100644 index 000000000000..8e01fa6f6215 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_set.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_unique(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ): + gv: R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ) = R.unique( + x, sorted=True, return_index=False, return_inverse=True, return_counts=True, axis=1 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.unique(x, sorted=True, return_inverse=True, return_counts=True, axis=1) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py b/tests/python/relax/test_tvmscript_parser_op_statistical.py new file mode 100644 index 000000000000..221d2a17a8b8 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_sum(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.sum(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_without_specified_axis(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_keep_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 3, 1), "float32"): + gv: R.Tensor((1, 1, 3, 1), "float32") = R.sum(x, axis=[1, 3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_mean(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.mean(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.mean(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_variance(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"): + gv: R.Tensor((1,), "float32") = R.variance(x, axis=[-1, -2, -3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.variance(x, axis=[-1, -2, -3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_min(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.min(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.min(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_prod(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> 
R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.prod(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.prod(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_std(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.std(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.std(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py new file mode 100644 index 000000000000..bffa741353a9 --- /dev/null +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -0,0 +1,533 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import tvm +import tvm.testing +from tvm import IRModule, relax, tir +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def _assert_print(obj, expected): + if not isinstance(obj, str): + obj = obj.script(verbose_expr=True) + obj = obj.strip() + assert obj == expected.strip(), "\n" + obj + + +def test_function(): + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + _assert_print( + func, + """ +# from tvm.script import relax as R + +@R.function +def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a""", + ) + + +def test_extern_func(): + @R.function + def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + obj = IRModule( + { + "func": relax_func, + "my_ext": relax.ExternFunc("my_ext"), + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + "my_ext" + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a +""", + ) + + +def test_object_struct_info(): + obj = relax.ObjectStructInfo() + _assert_print( + obj, + "R.Object", + ) + + +def test_prim_struct_info(): + obj = relax.PrimStructInfo("float32") + _assert_print(obj, 'R.Prim("float32")') + + +def test_shape_struct_info_0(): + obj = relax.ShapeStructInfo(ndim=-1) + _assert_print(obj, "R.Shape(ndim=-1)") + + +def test_shape_struct_info_1(): + obj = relax.ShapeStructInfo([1, 2, 3]) + _assert_print(obj, "R.Shape([1, 2, 3])") + + +def test_shape_struct_info_2(): + obj = relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]) + _assert_print( + obj, + """ +a = T.int64() +R.Shape([1, a, 3])""", + ) + + +def test_tensor_struct_info(): + obj = relax.TensorStructInfo( + shape=relax.ShapeExpr([1, tir.Var("a", "int64"), 3]), + dtype="float32", + ) + _assert_print( + obj, + """ +a = T.int64() +R.Tensor((1, a, 3), dtype="float32") +""", + ) + + +def test_tuple_struct_info_empty(): + obj = relax.TupleStructInfo([]) + _assert_print(obj, "R.Tuple") + + +def test_tuple_struct_info(): + obj = relax.TupleStructInfo( + [ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ] + ) + _assert_print( + obj, + """ +a = T.int64() +R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3])) +""", + ) + + +def test_func_struct_info(): + obj = relax.FuncStructInfo( + params=[ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ], + ret=relax.TensorStructInfo( + shape=relax.ShapeExpr([1, 2, 3]), + dtype="float32", + ), + ) + _assert_print( + obj, + """ +a = T.int64() +R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32")) +""", + ) + + +def test_shape_type(): + obj = relax.ShapeType(ndim=3) + _assert_print(obj, "R.Shape(ndim=3)") + + +def test_object_type(): + obj = relax.ObjectType() + _assert_print(obj, "R.Object") + + +def test_dyn_tensor_type(): + obj = relax.DynTensorType() + _assert_print(obj, 'R.Tensor(ndim=-1, dtype="float32")') + + +def test_packed_func_type(): + obj = relax.PackedFuncType() + _assert_print(obj, "R.PackedFunc") + + +def test_tuple_type(): + obj = relax.TupleType([relax.ShapeType(ndim=3), relax.ObjectType()]) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + "R.Tuple(R.Shape(ndim=3), R.Object)", + ) + + +def test_func_type(): + obj = relax.FuncType( + 
arg_types=[ + relax.ObjectType(), + relax.ShapeType(ndim=3), + ], + ret_type=relax.DynTensorType( + ndim=3, + dtype="float32", + ), + ) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + 'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(ndim=3, dtype="float32"))', + ) + + +def test_prim_value(): + obj = relax.PrimValue(1) + _assert_print(obj, "R.prim_value(1)") + + +def test_string_imm(): + obj = relax.StringImm("hello") + _assert_print(obj, 'R.str("hello")') + + +def test_data_type_imm(): + obj = relax.DataTypeImm("float32") + _assert_print(obj, 'R.dtype("float32")') + + +def test_var(): + obj = relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_dataflow_var(): + obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_tuple(): + obj = relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +y = T.int64() +b: R.Tensor((1, y, 3), dtype="float32") +z = T.int64() +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c) +""", + ) + + +def test_tuple_get_item(): + obj = relax.TupleGetItem( + relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ), + 0, + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +y = T.int64() +b: R.Tensor((1, y, 3), dtype="float32") +z = T.int64() +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c)[0] +""", + ) + + +def test_shape_expr(): + obj = relax.ShapeExpr([1, 2, 3]) + _assert_print(obj, "R.shape([1, 2, 3])") + + +def test_call(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x]) + o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info) + _assert_print( + o0, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) +""", + ) + _assert_print( + o1, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32")) +""", + ) + + +def test_seq_expr(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + + obj = relax.SeqExpr( + blocks=[ + relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ), + ], + body=c, + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = 
R.sin(b) + R.output(c) +c +""", + ) + + +def test_binding_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.BindingBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) +""", + ) + + +def test_dataflow_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) + R.output(c) +""", + ) + + +def test_match_cast(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3])) + b = relax.Var("b", relax.TensorStructInfo([1, 5, 3])) + obj = relax.MatchCast( + var=b, + value=a, + struct_info=b.struct_info, + ) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, 5, 3), dtype="float32") = R.match_cast(a, R.Tensor((1, 5, 3), dtype="float32")) +""", + ) + + +def test_var_binding(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.VarBinding(b, relax.op.sin(a)) + _assert_print( + obj, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +""", + ) + + +def test_if(): + a = relax.Var("a", relax.TensorStructInfo([], "bool")) + b = relax.Var("b", relax.TensorStructInfo([1, 2, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, 2, 3], "float32")) + obj = relax.If( + a, + relax.SeqExpr([], b), + relax.SeqExpr([], c), + ) + _assert_print( + obj, + """ +a: R.Tensor((), dtype="bool") +if a: + b: R.Tensor((1, 2, 3), dtype="float32") + b +else: + c: R.Tensor((1, 2, 3), dtype="float32") + c +""", + ) + + +def test_builtin_keywords(): + x = tir.Var("x", "int64") + a = relax.Var("R", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("T", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.VarBinding(b, relax.op.sin(a)) + _assert_print( + obj, + """ +x = T.int64() +R_1: R.Tensor((1, x, 3), dtype="float32") +T_1: R.Tensor((1, x, 3), dtype="float32") = R.sin(R_1) +""", + ) + + +def test_module_cross_func_call(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32") + ): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((128,), "float32")) -> R.Tensor((128,), "float32"): + cls = TestModule + gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128,), dtype="float32")) + return gv0 + + # default behavior + _assert_print( + TestModule, + """ +# from tvm.script import ir as I +# from tvm.script import tir as T +# from tvm.script import relax as R + +@I.ir_module +class Module: + 
@T.prim_func + def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"): + cls = Module + gv0 = R.call_tir(cls.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32")) + return gv0 +""", + ) + + # empty module alias + module_str = TestModule.script(module_alias="") + _assert_print( + module_str, + """ +# from tvm.script import ir as I +# from tvm.script import tir as T +# from tvm.script import relax as R + +@I.ir_module +class Module: + @T.prim_func + def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"): + gv0 = R.call_tir(Module.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32")) + return gv0 +""", + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py new file mode 100644 index 000000000000..15122dab3771 --- /dev/null +++ b/tests/python/relax/test_utils.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R + + +def test_copy_with_new_vars(): + @R.function + def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + after = relax.utils.copy_with_new_vars(before) + assert_structural_equal(after, before) + + assert len(after.params) == len(before.params) + for before_var, after_var in zip(before.params, after.params): + assert before_var != after_var + + +def test_copy_with_new_vars_on_ir_module(): + @tvm.script.ir_module + class Actual: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + @R.function + def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + + # Assertion will fail if the f_copied contains the same VarNode that's used in + # the original function, due to var mapping during structural equal. 
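+    # Note (illustrative comment): copy_with_new_vars is expected to allocate fresh
+    # Var nodes (checked directly in test_copy_with_new_vars above), so the copied
+    # function is structurally equal to the original while sharing no variables.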
+ assert_structural_equal(Actual, Expected) + + +def test_copy_with_new_vars_on_ir_module_nested_function(): + @tvm.script.ir_module + class Actual: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + @R.function + def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + + assert_structural_equal(Actual, Expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py new file mode 100644 index 000000000000..149138335557 --- /dev/null +++ b/tests/python/relax/test_vm_build.py @@ -0,0 +1,910 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
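+"""End-to-end tests for building Relax modules with relax.build and running the
+resulting executables on the Relax VM, in both the "bytecode" and "compiled"
+exec modes.
+
+A rough sketch of the pattern used throughout this file (names are illustrative)::
+
+    ex = relax.build(mod, tvm.target.Target("llvm", host="llvm"), exec_mode="bytecode")
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    out = vm["foo"](tvm.nd.array(np.random.rand(3, 4).astype(np.float32)))
+"""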
+import os
+from typing import Tuple, Callable
+
+import sys
+import tempfile
+import numpy as np
+import pytest
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax, rpc, te, tir, topi
+from tvm.contrib import utils
+from tvm.relax.testing import nn
+from tvm.script import relax as R, tir as T
+from tvm.relax.testing.vm import check_saved_func
+
+EXEC_MODE = ["bytecode", "compiled"]
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_simple(exec_mode):
+    @tvm.script.ir_module
+    class TestVMCompileStage0:
+        @R.function
+        def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_packed(
+                "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+            return y
+
+    mod = TestVMCompileStage0
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    vm["foo"](inp1, inp2)
+    tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_match_check(exec_mode):
+    @tvm.script.ir_module
+    class TestMatchCheck:
+        @R.function
+        def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], dtype=None):
+            return y
+
+    mod = TestMatchCheck
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32"))
+    y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32"))
+    y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32"))
+    y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32"))
+
+    vm["foo"](x0, y0)
+
+    with pytest.raises(RuntimeError, match=".*return.*"):
+        vm["foo"](x0, y1)
+
+    with pytest.raises(ValueError, match=".*return.*"):
+        vm["foo"](x0, y2)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_stage2(exec_mode):
+    @tvm.script.ir_module
+    class TestVMCompileStage2:
+        @R.function
+        def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
+            n, m = T.int64(), T.int64()
+            _ = R.match_cast(x, R.Tensor((n, m), "float32"))
+            return R.shape([n * 2, m * 3])
+
+    mod = TestVMCompileStage2
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    shape = (32, 16)
+    arr = tvm.nd.array(np.random.rand(*shape).astype("float32"))
+    res = vm["foo"](arr)
+    assert res[0] == shape[0] * 2
+    assert res[1] == shape[1] * 3
+
+    # dtype mismatch
+    with pytest.raises(ValueError, match=".*dtype.*"):
+        vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32")))
+
+    # ndim mismatch
+    with pytest.raises(ValueError, match=".*match_cast.*ndim.*"):
+        vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32")))
+
+    # type mismatch
+    with pytest.raises(TypeError):
+        vm["foo"]([])
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_compile_stage3(exec_mode):
+    @tvm.script.ir_module
+    class TestVMCompileStage3:
+        @R.function
+        def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
+            with R.dataflow():
+                y = R.call_dps_packed("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32"))
+                R.output(y)
+            return y
+
+    mod = TestVMCompileStage3
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    shape = (32, 16)
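+    # Illustrative note: "test.vm.identity" copies its input into the destination
+    # tensor provided by call_dps_packed, so the output is expected to equal the input.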
+ inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E: + @R.function + def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: + with R.dataflow(): + n, m = T.int64(), T.int64() + _ = R.match_cast(x, R.Tensor((n, m), "float32")) + y = R.call_dps_packed("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) + R.output(y) + return y + + mod = TestVMCompileE2E + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e_func_param_with_shape(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.int32() + n = T.int32() + k = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + m, k = T.int64(), T.int64() + cls = TestVMCompileE2E2 + gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + mod = TestVMCompileE2E2 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + res = check_saved_func(vm, "func", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_extern(exec_mode): + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n, m], "float32")) + y = relax.Var("y", R.Tensor([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = check_saved_func(vm, "rx_cblas_matmul", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_concat(exec_mode): + # concatenate of two vectors of size (n,) 
and (m,) + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([m], "float32")) + + def te_func(A, B): + C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + inp2 = tvm.nd.array( + np.random.rand( + 2, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp, inp2) + tvm.testing.assert_allclose( + res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 + ) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_dtype_change(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + # convert a tensor with dtype of float32 to int16 + def te_func(A): + B = te.compute((n,), lambda i: A[i].astype("int16")) + return B + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + bb.emit_func_output(y) + + mod = bb.get() + + new_mod = relax.transform.CallTIRRewrite()(mod) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp) + np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_floor_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + def te_func(A): + C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) + return C + + with bb.function("rx_func", [x]): + x1 = bb.emit_te(te_func, x) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (9,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp) + + def expected_output(): + output_shape = (shape[0] // 2,) + return inp.numpy()[: output_shape[0]] + 1 + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_constant_param_cpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + exec = relax.build(mod, "llvm", exec_mode=exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@tvm.testing.requires_gpu +def test_vm_emit_te_constant_param_gpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 
2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + sch = tvm.tir.Schedule(mod, debug_mask="all") + loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) + sch.bind(loops[0], "threadIdx.x") + + exec = relax.build(sch.mod, "cuda", exec_mode=exec_mode) + dev = tvm.cuda() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32")) + + def te_func(A, B): + C = te.compute((n,), lambda i: A[i] + B[i // 2]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape1 = (5,) + shape2 = (3,) + inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp, inp2) + + def expected_output(): + return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_dyn_tir_shape(exec_mode): + # case where TIR variables are unbound in generated PrimFunc + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n + 1,), dtype="float32", name="y") + + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1, params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + ex.export_library("exec.so") + vm = relax.VirtualMachine(tvm.runtime.load_module("exec.so"), tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) + + res = check_saved_func(vm, "rx_func", inp, inp2) + + tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuple(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n,), dtype="float32", name="y") + tup = relax.Tuple([x, y]) + item = tup[0] + bb.emit_func_output([tup, item], params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (5,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + (res1, res2), res3 = vm["rx_func"](inp, inp2) + + tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + 
tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuplegetitem(exec_mode): + @tvm.script.ir_module + class TestVMTupleGetItem: + @R.function + def tuple_get_item( + x: R.Tensor(ndim=2, dtype="float32"), + y: R.Tensor(ndim=2, dtype="float32"), + ): + t = (x, y) + a = t[0] + b = t[1] + c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return c + + mod = TestVMTupleGetItem + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_lower_memory_alloc_storage_tensor(exec_mode): + @tvm.script.ir_module + class TestMemoryAllocStorageTensor: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")): + cls = TestMemoryAllocStorageTensor + storage = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") + _ = cls.copy(x, y) + return y + + @T.prim_func + def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + for i0, i1 in T.grid(2, 3): + with T.block("block"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = A[vi0, vi1] + + mod = TestMemoryAllocStorageTensor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = vm["main"](x) + tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_sub_func_call(exec_mode): + @tvm.script.ir_module + class TestVMSubFunction: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.int32() + n = T.int32() + k = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def relax_matmul_tir( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Tensor((32, 32), dtype="float32"): + cls = TestVMSubFunction + with R.dataflow(): + gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + R.output(gv0) + return gv0 + + @R.function + def relax_matmul_packed( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Object: + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return gv0 + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object: + cls = TestVMSubFunction + gv0 = cls.relax_matmul_tir(x, w) + gv1 = cls.relax_matmul_packed(gv0, gv0) + return gv1 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestVMSubFunction, target, 
exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + res = check_saved_func(vm, "main", x_inp, y_inp) + product = np.dot(x_inp.numpy(), y_inp.numpy()) + expected = product * product + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_recursion(exec_mode): + @tvm.script.ir_module + class TestVMRecursion: + @R.function + def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: + cond = R.call_packed( + "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + if cond: + res = R.const(1.0) + else: + gv0 = R.call_packed( + "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + tmp = TestVMRecursion.recursion(gv0) + res = R.call_packed( + "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + return res + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestVMRecursion, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + inp = np.empty(1).astype("float32") + recursion_runs = np.random.randint(1, 10) + inp.fill(recursion_runs) + inp = tvm.nd.array(inp) + res = check_saved_func(vm, "recursion", inp) + tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_closure(exec_mode): + @tvm.script.ir_module + class TestClosure: + @R.function + def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): + return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor)) + + @R.function + def main( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((2, 3), "float32"), + ): + cls = TestClosure + clo = R.make_closure(cls.lifted_func_1, (x,)) + res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor)) + return res + + mod = TestClosure + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) + res = check_saved_func(vm, "main", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_time_evaluator(exec_mode): + @tvm.script.ir_module + class TestTimeEvaluator: + @R.function + def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): + return R.call_packed( + "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(TestTimeEvaluator, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(1).astype("float32")) + y = tvm.nd.array(np.random.rand(1).astype("float32")) + + # ensure we can use time_evaluator with the stateful API + vm.set_input("main", x, y) + timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main") + # just checking that it has some results at all + assert timing_res.results + + # ensure we can use it with a closure + vm.save_function("main", "saved_main", x, y) + timing_res = vm.time_evaluator("saved_main", tvm.cpu())() + assert timing_res.results + + +@tvm.script.ir_module +class TestVMSetInput: + @T.prim_func + def test_vm_mul(x: T.handle, y: 
T.handle, z: T.handle): + T.func_attr({"global_symbol": "test_vm_mul"}) + m = T.int32() + n = T.int32() + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("mul"): + vi = T.axis.spatial(m, i) + vj = T.axis.spatial(n, j) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = A[vi, vj] * B[vi, vj] + + # test returning a tuple + @R.function + def test_vm_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): + return (x, x) + + # nested tuple too + @R.function + def test_vm_nested_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((), "int32"), + R.Tuple( + R.Tensor((), "int32"), + ), + ), + R.Tensor((), "int32"), + ): + return ((x, (x,)), x) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + cls = TestVMSetInput + gv0 = R.call_tir(cls.test_vm_mul, (x, w), R.Tensor((32, 32), dtype="float32")) + return gv0 + + +def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + res0 = vm.get_outputs("main") + + data_dict = {"x": a, "w": b} + vm.set_input("main", **data_dict) + vm.invoke_stateful("main") + res1 = vm.get_outputs("main") + tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) + + # bug! If you don't bind the NDArray to a var, the memory will get corrupted. + # Possibly due to object lifecycles and other FFI issues + a = tvm.nd.array(np.array(2).astype("int32"), device) + vm.set_input("test_vm_tuple", a) + vm.invoke_stateful("test_vm_tuple") + res2 = vm.get_outputs("test_vm_tuple") + # the results are NDArrays wrapped around scalars, + # so we have to get the scalar out of the NDArray + assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) + + b = tvm.nd.array(np.array(1).astype("int32"), device) + vm.set_input("test_vm_nested_tuple", b) + vm.invoke_stateful("test_vm_nested_tuple") + res3 = vm.get_outputs("test_vm_nested_tuple") + assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1 + result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy())) + assert result_cast == ((1, (1,)), 1) + + +def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: once you set inputs, you cannot run statelessly + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + # must use invoke stateful! 
+ vm["main"]() + + +def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: if the function needs inputs, you can't invoke directly + vm.invoke_stateful("main") + + +def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: you can't get outputs without invoking the function first + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + _ = vm.get_outputs("main") + + +def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: + """Returns a local VM for the given mod and the device""" + target = tvm.target.Target("llvm", host="llvm") + exec = relax.build(TestVMSetInput, target, exec_mode=exec_mode) + exec.export_library("exec.so") + exec_loaded = tvm.runtime.load_module("exec.so") + os.remove("exec.so") + device = tvm.cpu() + return relax.VirtualMachine(exec_loaded, device), device + + +def run_on_rpc( + mod: tvm.IRModule, + trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None], + exec_mode: str, +): + """ + Sets up a VM over localhost using the given mod and runs the given trial function. + The trial function should take a VM and a device + """ + target = tvm.target.Target("llvm", host="llvm") + exec = relax.build(mod, target, exec_mode=exec_mode) + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + exec.export_library(path) + + # Use local rpc server for testing. + # Server must use popen so it doesn't inherit the current process state. It + # will crash otherwise. + # Adapted from relay/test_vm.py + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") + + device = remote.cpu() + # Build a VM out of the executable and context. 
+        vm = relax.VirtualMachine(rexec, device=device)
+        trial_func(vm, device)
+
+    check_remote(rpc.Server("127.0.0.1"))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_set_input(exec_mode):
+    set_input_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    # just checking that we can use kwargs for the args when saving a function
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", x=a, w=b)
+    res0 = vm["saved_main"]()
+    tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs(exec_mode):
+    save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, save_function_kwargs_trial, exec_mode)
+
+
+def save_function_time_evaluator_trial(
+    vm: relax.VirtualMachine, device: tvm.runtime.Device
+) -> None:
+    # just checking that the saved function can be called in the time evaluator
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", a, b)
+    vm.time_evaluator("saved_main", device)()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator(exec_mode):
+    save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode)
+
+
+# if you set an input, you should not be able to call statelessly
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure(exec_mode):
+    set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_stateless, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure(exec_mode):
+    set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure(exec_mode):
+    set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
new file mode 100644
index 000000000000..b9904429f3b8
--- /dev/null
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test last-stage of codegen VM. + +Restrictions: all shape lowered, explicit allocation. +""" +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode +from tvm.relax.testing.vm import check_saved_func +from tvm.script import relax as R +from tvm.script import tir as T + +EXEC_MODE = ["bytecode", "compiled"] + + +def codegen(mod, target, exec_mode="bytecode"): + builder = relax.ExecBuilder() + tir_mod = relax.vm_build._vmcodegen(builder, mod, exec_mode=exec_mode) + return relax.vm_build._vmlink(builder, target, tir_mod) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_copy(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond_const(exec_mode): + @tvm.script.ir_module + class TestVMIfCondConst: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"): + R.func_attr({"global_symbol": "main"}) + if relax.const(True, dtype="bool"): + ret = x + else: + ret = x + return ret + + mod = TestVMIfCondConst + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_exec_serialize_export_library(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target) + from tvm.contrib import utils + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + ex.export_library(path_exec) + + loaded_exec = tvm.runtime.load_module(path_exec) + assert ex.as_text() == loaded_exec["as_text"]() + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond(exec_mode): + @tvm.script.ir_module + class TestVMCompileIf: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: + 
R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + mod = TestVMCompileIf + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["ife"](tvm.nd.array(1), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(True), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(0), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(False), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_return_const_tuple(exec_mode): + @tvm.script.ir_module + class ReturnConstTuple: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + mod = ReturnConstTuple + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2, 3)) + res0, res1, res2 = vm["main"](inp) + tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2])) + tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4])) + tvm.testing.assert_allclose(res2.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_const_as_call_arg(exec_mode): + @tvm.script.ir_module + class TestVMConstAsCallArg: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + a = R.call_packed( + "test.vm.add", + relax.const([1, 2]), + relax.const([3, 4]), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + b = R.call_packed( + "test.vm.add", + a, + x, + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + return b + + mod = TestVMConstAsCallArg + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(1, 2)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_shape_check_builtin(exec_mode): + MS = MatchShapeCode + MK = MakeShapeCode + # slot assignment: + # 0: n, 1: m + sindex = {"n": 0, "m": 1} + + @tvm.script.ir_module + class TestVMShapeCheck: + @R.function + def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): + R.func_attr({"global_symbol": "main"}) + n = T.int64() + k = T.int64() + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(3)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + # construct shape value for return + s = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 3, + MK.LOAD_SHAPE, + sindex["m"], + MK.LOAD_SHAPE, + sindex["n"], + MK.USE_IMM, + 2, + sinfo_args=[R.Shape(ndim=3)], + ) + 
return s + + mod = TestVMShapeCheck + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.zeros((1, 2)).astype("float32")) + res = vm["main"](x) + assert res == tvm.runtime.container.ShapeTuple([2, 1, 2]) + + # wrong input type + with pytest.raises(TypeError): + vm["main"]([]) + + # wrong ndim + with pytest.raises(ValueError, match=r".*ndim.*"): + vm["main"](tvm.nd.array(np.zeros(1).astype("float32"))) + + # wrong dtype + with pytest.raises(ValueError, match=r".*dtype.*"): + vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_prim_value(exec_mode): + @tvm.script.ir_module + class TestVMPrimValue: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.prim_value(T.int64(1)) + return ret + + mod = TestVMPrimValue + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == 1 + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_string_imm(exec_mode): + @tvm.script.ir_module + class TestVMStringImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.str("hello") + return ret + + mod = TestVMStringImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "hello" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_datatype_imm(exec_mode): + @tvm.script.ir_module + class TestDataTypeImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.dtype("float32") + return ret + + mod = TestDataTypeImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "float32" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_builtin_reshape(exec_mode): + @tvm.script.ir_module + class TestVMBuiltinReshape: + @R.function + def main(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "main"}) + y = R.call_packed( + "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32") + ) + return y + + mod = TestVMBuiltinReshape + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(ex, dev) + + input_np = np.random.rand(3, 4).astype("float32") + input = tvm.nd.array(input_np, dev) + res = vm["main"](input) + expected = input_np.reshape(6, 2) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py new file mode 100644 index 000000000000..d82715a3946f --- /dev/null +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test the TIR codegen path of VM compiled mode. + +Restrictions: all shape lowered, explicit allocation. +""" +import tvm +import tvm.testing +from tvm import relax +from tvm.ir import assert_structural_equal +from tvm.script import relax as R +from tvm.script import tir as T + + +def get_tir_mod(mod): + builder = relax.ExecBuilder() + return relax.vm_build._vmcodegen(builder, mod, exec_mode="compiled") + + +def test_add(): + @tvm.script.ir_module + class Before: + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.anylist_setitem_call_packed( + r, + T.int32(2), + "test.vm.add", + T.anylist_getitem(r, T.int32(0)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_tir_call(): + @tvm.script.ir_module + class Before: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + _ = Before.shape_func(x) + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.call_cpacked( + "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0)) + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_if_cond(): + @tvm.script.ir_module + class Before: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: + R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__ife"}) + if T.cast( + T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + "bool", + ): + T.anylist_setitem_call_packed( + r, + T.int32(4), + "test.vm.add", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, 
T.int32(4)) + ) + else: + T.anylist_setitem_call_packed( + r, + T.int32(5), + "test.vm.mul", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(5)) + ) + T.anylist_setitem_call_packed( + r, T.int32(2), "vm.builtin.copy", T.anylist_getitem(r, T.int32(3)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + T.int32(2), + "vm.builtin.make_tuple", + T.anylist_getitem(c, T.int32(0)), + T.anylist_getitem(c, T.int32(1)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const_call(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + 2, + "test.vm.add", + T.anylist_getitem(r, 0), + T.anylist_getitem(c, 0), + ) + T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2)) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py new file mode 100644 index 000000000000..bd4b3fe90f10 --- /dev/null +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
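+"""Tests for Relax VM execution through the CUDA graph runtime builtins
+(vm.builtin.cuda_graph.get_cached_alloc and vm.builtin.cuda_graph.run_or_capture)."""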
+ +import tvm +from tvm.script import tir as T, relax as R, ir as I +from tvm import relax +import tvm.testing +import numpy as np + + +# fmt: off + + +@I.ir_module +class Module: + @R.function + def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"): + cls = Module + R.func_attr({"global_symbol": "main"}) + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) + storage: R.Object = gv[0] + alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + _: R.Tuple = cls.add(x, alloc) + storage1: R.Object = gv[1] + gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage) + gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),)) + storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("float32")) + alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + lv4: R.Tensor((16, 16), dtype="float32") = gv2[0] + _3: R.Tuple = cls.add(lv4, alloc3) + lv5: R.Tensor(dtype="float32") = alloc3 + return lv5 + + @T.prim_func + def add(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): + T.func_attr({"global_symbol": "add"}) + with T.block("root"): + for i in T.thread_binding(16, thread="threadIdx.x"): + for j in range(16): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): + R.func_attr({"global_symbol": "cuda_graph_alloc"}) + storage: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("float32")) + storage1: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("float32")) + gv: R.Tuple(R.Object, R.Object) = (storage, storage1) + return gv + + @R.function + def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.Object, storage: R.Object) -> R.Tuple(R.Tensor((16, 16), dtype="float32")): + cls = Module + R.func_attr({"global_symbol": "cuda_graph_capture"}) + lv0: R.Tensor((16, 16), dtype="float32") = alloc + alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + _1: R.Tuple = cls.add(lv0, alloc1) + lv1: R.Tensor(dtype="float32") = alloc1 + lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,) + lv3: R.Tensor(dtype="float32") = lv2[0] + alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + _2: R.Tuple = cls.add(lv3, alloc2) + lv4: R.Tensor(dtype="float32") = alloc2 + gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,) + return gv + + +# fmt: on + + +def codegen(mod, target, exec_mode="bytecode"): + builder = relax.ExecBuilder() + leftover_mod = relax.vm_build._vmcodegen(builder, mod, exec_mode=exec_mode) + tir_mod = relax.vm_build._filter_tir(leftover_mod) + return relax.vm_build._vmlink(builder, target, tir_mod) + + +@tvm.testing.requires_cuda +def test_vm_run(): + mod = Module + target = tvm.target.Target("cuda", host="llvm") + ex = codegen(mod, target) + dev = tvm.cuda(0) + vm = relax.VirtualMachine(ex, dev) + x_np = np.random.uniform(size=(16, 16)).astype("float32") + x = 
tvm.nd.array(x_np, dev) + y = vm["main"](x) + y_np = x_np + 1.0 + 1.0 + 1.0 + 1.0 + tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py new file mode 100644 index 000000000000..9a7cd0c87938 --- /dev/null +++ b/tests/python/relax/test_vm_execbuilder.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Lowest level testing VM. Test execbuilder and execution.""" +import tvm +import pytest +import numpy as np +from tvm import relax, TVMError +from tvm.relax.testing.vm import check_saved_func + + +def test_vm_execute(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_multiple_func(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + with ib.function("func1", num_inputs=2): + ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + mul_res = check_saved_func(vm, "func1", a, b) + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(mul_res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_checker(): + ib = relax.ExecBuilder() + with pytest.raises(TVMError): + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ib.get() + + +def test_neg_imm(): + ib = relax.ExecBuilder() + + with ib.function("func0", num_inputs=1): + ib.emit_call("test.vm.add_scalar", args=[ib.imm(-3), ib.r(0)], dst=ib.r(1)) + ib.emit_ret(ib.r(1)) + ib.get() + + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + assert vm["func0"](1) == -2 + assert vm["func0"](-3) == -6 + + +def test_emit_cache(): + ib = relax.ExecBuilder() + + with ib.function("func0", num_inputs=1): + x0 = ib.convert_constant("str0") + x1 = ib.convert_constant("str0") + # cache constant str + assert x0 == x1 + s0 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2])) + s1 = 
ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2])) + s2 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 3])) + assert s0 == s1 + assert s1 != s2 + y0 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) + y1 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) + assert y0 == y1 + ib.emit_ret(ib.r(0)) + + +def test_vm_formalize(): + ib0 = relax.ExecBuilder() + ib1 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(100)) + ib0.emit_call("test.vm.mul", args=[ib0.r(1), ib0.r(100)], dst=ib0.r(50)) + ib0.emit_ret(ib0.r(50)) + with ib1.function("func0", num_inputs=2): + ib1.emit_call("test.vm.add", args=[ib1.r(0), ib1.r(1)], dst=ib1.r(2)) + ib1.emit_call("test.vm.mul", args=[ib1.r(1), ib1.r(2)], dst=ib1.r(3)) + ib1.emit_ret(ib1.r(3)) + exec0 = ib0.get() + exec1 = ib1.get() + assert exec0.as_text() == exec1.as_text() + + +def test_vm_operand(): + ib0 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add_scalar", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(2)) + ib0.emit_ret(ib0.r(2)) + exec0 = ib0.get() + vm = relax.VirtualMachine(exec0, tvm.cpu()) + res = vm["func0"](2, 3) + assert res == 5 + + ib1 = relax.ExecBuilder() + with ib1.function("func1", num_inputs=1): + ib1.emit_call("test.vm.get_device_id", args=[ib1.r(0)], dst=ib1.r(1)) + ib1.emit_ret(ib1.r(1)) + exec1 = ib1.get() + vm = relax.VirtualMachine(exec1, tvm.cpu()) + res = vm["func1"](tvm.cpu(3)) + assert res == 3 + + +def test_vm_shapeof(): + ib = relax.ExecBuilder() + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(0)) + ib.emit_ret(ib.r(0)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + for i, s in enumerate(res): + assert s == shape[i] + + +def test_vm_storage(): + dtype = tvm.DataType("float32") + shape = (4, 6) + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=0): + ib.emit_call( + "vm.builtin.alloc_storage", + args=[ib.vm_state(), (24,), ib.convert_constant(0), dtype], + dst=ib.r(1), + ) + ib.emit_call( + "vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), shape, dtype], dst=ib.r(2) + ) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res.device == tvm.cpu() + assert res.shape == shape + + +def test_vm_goto(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(2), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = check_saved_func(vm, "main", a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_if(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=3): + ib.emit_if(ib.r(0), 3) + ib.emit_call("test.vm.add", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_ret(ib.r(3)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = vm["main"](0, a, b) + 
tvm.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + res = vm["main"](1, a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_invoke_closure(): + ib = relax.ExecBuilder() + with ib.function("lifted_func_1", num_inputs=4): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(4)) + ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(4)], dst=ib.r(5)) + ib.emit_call("test.vm.add", args=[ib.r(3), ib.r(5)], dst=ib.r(6)) + ib.emit_ret(ib.r(6)) + with ib.function("main", num_inputs=2): + ib.emit_call( + "vm.builtin.make_closure", args=[ib.f("lifted_func_1"), ib.r(0), ib.r(1)], dst=ib.r(2) + ) + ib.emit_ret(ib.r(2)) + + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + w_inp = tvm.nd.array(np.random.rand(2, 3)) + x_inp = tvm.nd.array(np.random.rand(2, 3)) + y_inp = tvm.nd.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z_inp = tvm.nd.array(np.random.rand(2, 3)) + clo = vm["main"](w_inp, x_inp) + res = vm.invoke_closure(clo, y_inp, z_inp) + tvm.testing.assert_allclose( + res.numpy(), w_inp.numpy() + x_inp.numpy() + y_inp.numpy() + z_inp.numpy() + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_instrument.py b/tests/python/relax/test_vm_instrument.py new file mode 100644 index 000000000000..8297da1b744f --- /dev/null +++ b/tests/python/relax/test_vm_instrument.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
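The closure test above exercises two pieces of the Relax VM calling convention: "vm.builtin.make_closure" binds the leading registers as captured arguments of a lifted function, and VirtualMachine.invoke_closure supplies the remaining ones. A minimal sketch of the same pattern with a single captured value — the function name "lifted_add" is illustrative, and only the "test.vm.add" packed function already used by these tests is assumed:

import numpy as np
import tvm
from tvm import relax

ib = relax.ExecBuilder()
# lifted_add(w, x): w is captured at closure-creation time, x at call time.
with ib.function("lifted_add", num_inputs=2):
    ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
with ib.function("main", num_inputs=1):
    # Close over register 0 (w); the resulting closure still expects one argument.
    ib.emit_call("vm.builtin.make_closure", args=[ib.f("lifted_add"), ib.r(0)], dst=ib.r(1))
    ib.emit_ret(ib.r(1))

vm = relax.VirtualMachine(ib.get(), tvm.cpu())
w = tvm.nd.array(np.random.rand(2, 3))
x = tvm.nd.array(np.random.rand(2, 3))
clo = vm["main"](w)
res = vm.invoke_closure(clo, x)  # expected to equal w + x elementwise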
+import numpy as np +import tvm +import tvm.testing + +from tvm import relax +from tvm.relax.testing import nn +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument + + +def get_exec(data_shape): + builder = relax.BlockBuilder() + weight1_np = np.random.randn(64, 64).astype("float32") + weight2_np = np.random.randn(64, 64).astype("float32") + + with builder.function("main"): + model = nn.Sequential( + nn.Linear(data_shape[1], weight1_np.shape[0], bias=False), + nn.ReLU(), + nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False), + nn.ReLU(), + ) + data = nn.Placeholder(data_shape, name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + mod = builder.get() + + params = {"linear_weight": weight1_np, "linear_weight1": weight2_np} + mod = relax.transform.BindParams("main", params)(mod) + + target = "llvm" + return relax.build(mod, target) + + +def test_conv2d_cpu(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + vm = relax.VirtualMachine(ex, tvm.cpu()) + hit_count = {} + + def instrument(func, name, before_run, ret_val, *args): + if (name, before_run) not in hit_count: + hit_count[(name, before_run)] = 0 + hit_count[(name, before_run)] += 1 + assert callable(func) + if before_run: + assert ret_val is None + if name == "matmul": + return relax.VMInstrumentReturnKind.SKIP_RUN + + vm.set_instrument(instrument) + vm["main"](tvm.nd.array(data_np)) + assert hit_count[("matmul", True)] == 2 + assert ("matmul", False) not in hit_count + assert hit_count[("relu", True)] == 2 + assert hit_count[("relu", False)] == 2 + + +def test_lib_comparator(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + vm = relax.VirtualMachine(ex, tvm.cpu()) + # compare against library module + cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(), verbose=False) + vm.set_instrument(cmp) + vm["main"](tvm.nd.array(data_np)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_profiler.py b/tests/python/relax/test_vm_profiler.py new file mode 100644 index 000000000000..114596741113 --- /dev/null +++ b/tests/python/relax/test_vm_profiler.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
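The instrument hook in test_vm_instrument.py above receives (func, name, before_run, ret_val, *args) and may return relax.VMInstrumentReturnKind.SKIP_RUN to skip a kernel. Under that same assumed signature, an instrument can also be a stateful object; the TimingInstrument below is a hypothetical sketch of a per-kernel wall-clock timer:

import time

class TimingInstrument:
    """Accumulate wall-clock time per packed function invoked by the VM."""

    def __init__(self):
        self.durations = {}  # kernel name -> accumulated seconds
        self._start = {}

    def __call__(self, func, name, before_run, ret_val, *args):
        if before_run:
            self._start[name] = time.perf_counter()
        else:
            elapsed = time.perf_counter() - self._start.pop(name)
            self.durations[name] = self.durations.get(name, 0.0) + elapsed
        # Returning None keeps the default behavior: the kernel runs normally.

# Hypothetical usage, mirroring the test above:
#   vm.set_instrument(TimingInstrument())
#   vm["main"](tvm.nd.array(data_np))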
+import numpy as np +import tvm +import tvm.testing + +from tvm import relax, rpc +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R + + +def get_exec(data_shape): + builder = relax.BlockBuilder() + weight1_np = np.random.randn(64, 64).astype("float32") + weight2_np = np.random.randn(64, 64).astype("float32") + + with builder.function("main"): + model = nn.Sequential( + nn.Linear(data_shape[1], weight1_np.shape[0], bias=False), + nn.ReLU(), + nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False), + nn.ReLU(), + ) + data = nn.Placeholder(data_shape, name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + mod = builder.get() + + params = {"linear_weight": weight1_np, "linear_weight1": weight2_np} + mod = relax.transform.BindParams("main", params)(mod) + + target = "llvm" + return relax.build(mod, target) + + +def test_conv2d_cpu(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True) + report = vm.profile("main", tvm.nd.array(data_np)) + print(report) + + assert "Duration" in str(report) + assert "matmul" in str(report) + + +def with_rpc(ex, f, data_np): + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + ex.export_library(path) + + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port, session_timeout=10) + + remote.upload(path) + rexec = remote.load_module("vm_library.so") + + device = remote.cpu() + + vm = relax.VirtualMachine(rexec, device=device, profile=True) + data = tvm.nd.array(data_np, device) + + f(vm, data) + + +def test_rpc(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + def callback(vm, data): + vm.profile("main", data) + + vm.set_input("main", data) + report = vm.profile("main") + + assert "matmul" in str(report) + print(report) + + with_rpc(ex, callback, data_np) + + +def test_tuple(): + @tvm.script.ir_module + class NestedTuple: + @R.function + def main( + x: R.Tensor((16,), "float32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((16,), "float32"), + R.Tuple( + R.Tensor((16,), "float32"), + ), + ), + R.Tensor((16,), "float32"), + ): + return ((x, (x,)), x) + + target = "llvm" + ex = relax.build(NestedTuple, target) + + data_np = np.random.randn(16).astype("float32") + + def callback(vm, data): + report = vm.profile("main", data) + assert "vm.builtin.make_tuple" in str(report) + + with_rpc(ex, callback, data_np) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6443d50f9e98..63ff66eaa291 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -862,7 +862,7 @@ def prepare_vm_model(path, tensor_shape): vm_exec = vm.compile(mod, target=target) # Export to Disk - vm_exec.mod.export_library(path) + vm_exec.export_library(path) def test_vm_rpc(): @@ -1393,7 +1393,7 @@ def test_large_constants(): path_consts = temp.relpath("consts") vm_exec.move_late_bound_consts(path_consts, byte_limit=256) path_dso = temp.relpath("lib.so") - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) # Load library files and constants mod = runtime.load_module(path_dso) @@ -1442,7 +1442,7 @@ def test_load_late_bound_consts_with_no_late_bound_consts(): # Ensure const_data is below the byte threshold for a late-bound const. 
byte_limit = len(const_data.tobytes()) + 1 vm_exec.move_late_bound_consts(path_consts, byte_limit=byte_limit) - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) mod = runtime.load_module(path_dso) mod["load_late_bound_consts"](path_consts) @@ -1503,7 +1503,7 @@ def test_load_and_save_constants_via_map(): # Save to constants and library files temp = utils.tempdir() path_dso = temp.relpath("lib.so") - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) # Load library files and constants mod = runtime.load_module(path_dso) @@ -1551,7 +1551,7 @@ def test_load_late_bound_consts_via_map_with_no_late_bound_consts(): # Ensure const_data is below the byte threshold for a late-bound const. byte_limit = len(const_data.tobytes()) + 1 consts_map = vm_exec.get_late_bound_consts(byte_limit=byte_limit) - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) mod = runtime.load_module(path_dso) mod["load_late_bound_consts_from_map"](consts_map) diff --git a/tests/python/topi/python/test_topi_group_norm.py b/tests/python/topi/python/test_topi_group_norm.py index f09442391672..8f8ab75b8a2e 100644 --- a/tests/python/topi/python/test_topi_group_norm.py +++ b/tests/python/topi/python/test_topi_group_norm.py @@ -34,7 +34,8 @@ # only test on llvm because schedule is missing @tvm.testing.parametrize_targets("llvm") @pytest.mark.parametrize("shape, axis", [([2, 4, 16], (2,)), ([2, 4, 4, 16], (2, 3))]) -def test_group_norm(target, dev, shape, axis, epsilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5): +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_group_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=1e-5, atol=1e-5): data = te.placeholder(shape, dtype=dtype, name="data") num_groups = 2 channel_axis = 1 diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_layer_norm.py index f875bb09e2a4..ff9eedd4e5e4 100644 --- a/tests/python/topi/python/test_topi_layer_norm.py +++ b/tests/python/topi/python/test_topi_layer_norm.py @@ -34,7 +34,8 @@ # only test on llvm because schedule is missing @tvm.testing.parametrize_targets("llvm") @pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) -def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5): +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_layer_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-4, atol=5e-4): data = te.placeholder(shape, dtype=dtype, name="data") scale_shape = [shape[dim] for dim in axis] gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma") diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index 3c4c170d0dd9..71ce654913f1 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -26,6 +26,7 @@ import tvm.topi.testing from tvm import te, topi +from tvm.topi.utils import get_const_tuple in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters( ((32,), 0, False, "argmax", "float32"), @@ -191,5 +192,43 @@ def test_complex_reduce(target, dev): tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) +data_shape, target_shape = tvm.testing.parameters( + ((2, 3), (3,)), + ((2, 3, 4), (2, 1, 4)), + ((2, 3, 4, 5), (3, 1, 5)), +) + + +def _my_npy_collapse_sum(data, target_shape): + reduce_axes = [] + i = data.ndim - 1 + j = len(target_shape) - 1 + while i >= 0: + if j < 0: + reduce_axes.append(i) 
+ elif target_shape[j] == 1 and data.shape[i] > 1: + reduce_axes.append(i) + i -= 1 + j -= 1 + return np.sum(data, tuple(reduce_axes)).reshape(target_shape) + + +def test_collapse_sum(data_shape, target_shape): + A = te.placeholder(data_shape, name="A") + B = topi.collapse_sum(A, target_shape) + s = te.create_schedule([B.op]) + + a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) + b_np = _my_npy_collapse_sum(a_np, target_shape) + dev = tvm.cpu(0) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + # Building with the CSE pass disabled + with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): + foo = tvm.build(s, [A, B], "llvm", name="collapse_sum") + foo(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_arith_detect_cse.py b/tests/python/unittest/test_arith_detect_cse.py index eba0920cb2da..dd7362ff1b7c 100644 --- a/tests/python/unittest/test_arith_detect_cse.py +++ b/tests/python/unittest/test_arith_detect_cse.py @@ -20,9 +20,9 @@ def test_detect_cs(): - x = T.Var("x", dtype="int32") - y = T.Var("y", dtype="int32") - z = T.Var("z", dtype="int32") + x = T.int32() + y = T.int32() + z = T.int32() c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5)) m = tvm.arith.detect_common_subexpr(c, 2) assert c.a.a in m diff --git a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py index d1ba84d836be..2dbf14d45195 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py +++ b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py @@ -31,6 +31,7 @@ from tvm.script import tir as T from tvm.tir.schedule import BlockRV +# fmt: off # Small gpu parameters which should work for nearly every (modern-ish) gpu. TARGET = tvm.target.Target( "cuda -max_threads_per_block=32 -max_num_threads=128 -thread_warp_size=32 -max_shared_memory_per_block=8192 -registers_per_block=1024" diff --git a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py new file mode 100644 index 000000000000..f275d438a740 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
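For intuition, test_collapse_sum above checks that topi.collapse_sum is the reduction that undoes broadcasting: every axis that the target shape would have expanded (size 1, or missing entirely) is summed out. A NumPy-only illustration of the reference behavior, using one of the parameterizations above:

import numpy as np

data = np.arange(24, dtype="float32").reshape(2, 3, 4)
target_shape = (2, 1, 4)

# Axis 1 has size 1 in the target, so it is the axis summed out.
collapsed = data.sum(axis=1, keepdims=True)
assert collapsed.shape == target_shape

# Broadcasting the collapsed result back recovers the original shape,
# which is the role collapse_sum typically plays as the adjoint of a broadcast.
assert np.broadcast_to(collapsed, data.shape).shape == data.shape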
+import pytest +import tvm +from tvm import TVMError +from tvm.script import tir as T +import tvm.testing + + +def test_thread_axis1(): + @T.prim_func + def before(A: T.Buffer((T.int64(64),), "float32"), B: T.Buffer((T.int64(64),), "float32")): + blockIdx_x = T.env_thread("blockIdx.x") + T.launch_thread(blockIdx_x, T.int64(2)) + threadIdx_x = T.env_thread("threadIdx.x") + T.launch_thread(threadIdx_x, T.int64(32)) + B[T.Cast("int64", blockIdx_x) * T.int64(32) + T.Cast("int64", threadIdx_x)] = A[ + T.Cast("int64", blockIdx_x) * T.int64(32) + T.Cast("int64", threadIdx_x) + ] + T.float32(1) + + @T.prim_func + def expected(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): + blockIdx_x = T.env_thread("blockIdx.x") + T.launch_thread(blockIdx_x, 2) + threadIdx_x = T.env_thread("threadIdx.x") + T.launch_thread(threadIdx_x, 32) + B[blockIdx_x * 32 + threadIdx_x] = A[blockIdx_x * 32 + threadIdx_x] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + tvm.ir.assert_structural_equal(func, expected) + + +def test_thread_axis2(): + @T.prim_func + def before( + T_reshape: T.Buffer((1, 12, 384, 384), "float32"), + placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"), + T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32"), + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): + for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)): + with T.block("T_where"): + ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + ax1 = T.axis.spatial( + T.int64(12), + ( + (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) + * T.int64(1024) + + i0_i1_i2_i3_fused_2 + ) + % T.int64(1769472) + // T.int64(147456), + ) + ax2 = T.axis.spatial( + T.int64(384), + ( + (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) + * T.int64(1024) + + i0_i1_i2_i3_fused_2 + ) + % T.int64(147456) + // T.int64(384), + ) + ax3 = T.axis.spatial( + 384, + T.cast( + ( + (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) + * T.int64(1024) + + i0_i1_i2_i3_fused_2 + ) + % T.int64(384), + "int32", + ), + ) + T.where( + (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) + * T.int64(1024) + + i0_i1_i2_i3_fused_2 + < T.int64(1769472) + ) + T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3]) + T.writes(T_where[ax0, ax1, ax2, ax3]) + T_where[ax0, ax1, ax2, ax3] = T.Select( + T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0, + T.float32(-1000000000), + T_reshape[ax0, ax1, ax2, ax3], + ) + + @T.prim_func + def expected( + T_reshape: T.Buffer((1, 12, 384, 384), "float32"), + placeholder_1: T.Buffer((1, 12, 384, 384), "bool"), + T_where: T.Buffer((1, 12, 384, 384), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): + for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): + for i0_i1_i2_i3_fused_0 in range(7): + with T.block("T_where"): + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial( + 12, + ( + (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + + i0_i1_i2_i3_fused_2 + ) + % 1769472 + // 147456, + ) + ax2 = T.axis.spatial( + 384, + ( + (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + + i0_i1_i2_i3_fused_2 + ) + % 147456 + // 384, + ) + ax3 = T.axis.spatial( + 384, + ( 
+ (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + + i0_i1_i2_i3_fused_2 + ) + % 384, + ) + T.where( + (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + + i0_i1_i2_i3_fused_2 + < 1769472 + ) + T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3]) + T.writes(T_where[ax0, ax1, ax2, ax3]) + T_where[ax0, ax1, ax2, ax3] = T.Select( + T.Cast("int32", placeholder_1[ax0, ax1, ax2, ax3]) != 0, + T.float32(-1000000000), + T_reshape[ax0, ax1, ax2, ax3], + ) + + mod = tvm.IRModule.from_expr(before) + func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + tvm.ir.assert_structural_equal(func, expected) + + +def test_block(): + @T.prim_func + def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + for i in T.serial(0, T.int64(16)): + for j in T.serial(0, T.int64(8)): + with T.block(): + vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j) + B[vi] = A[vi] + T.float32(1) + + @T.prim_func + def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + for i in T.serial(0, T.int32(16)): + for j in T.serial(0, T.int32(8)): + with T.block(): + vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) + B[vi] = A[vi] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + tvm.ir.assert_structural_equal(func, expected) + + +def test_fail_on_buffer_map(): + @T.prim_func + def func(A: T.Buffer((128,), "int64"), B: T.Buffer((128,), "int64")): + for i in T.serial(0, 16): + for j in T.serial(0, 8): + with T.block(): + vi = T.axis.spatial(128, i * 8 + j) + B[vi] = A[vi] + T.int64(1) + + mod = tvm.IRModule.from_expr(func) + with pytest.raises(TVMError): + tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + + +def test_fail_on_buffer_map(): + @T.prim_func + def func(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): + C = T.alloc_buffer((128,), "int64") + for i in T.serial(0, 16): + for j in T.serial(0, 8): + with T.block(): + vi = T.axis.spatial(128, i * 8 + j) + C[vi] = T.cast(A[vi], "int64") + T.int64(1) + for i in T.serial(0, 16): + for j in T.serial(0, 8): + with T.block(): + vi = T.axis.spatial(128, i * 8 + j) + B[vi] = T.cast(C[vi] + T.int64(1), "int32") + + mod = tvm.IRModule.from_expr(func) + with pytest.raises(TVMError): + tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py new file mode 100644 index 000000000000..644a9aede0cd --- /dev/null +++ b/tests/python/unittest/test_transform_default_gpu_schedule.py @@ -0,0 +1,417 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
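Beyond the structural-equality checks above, ForceNarrowIndexToInt32 can be applied directly to a module to narrow int64 loop extents and indices to int32; as the failure tests show, it raises an error when int64-typed buffers themselves are involved. A minimal sketch, reusing the int64-indexed loop nest from test_block above:

import tvm
from tvm.script import tir as T

@T.prim_func
def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i in T.serial(0, T.int64(16)):
        for j in T.serial(0, T.int64(8)):
            with T.block():
                vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j)
                B[vi] = A[vi] + T.float32(1)

mod = tvm.IRModule.from_expr(before)
narrowed = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)
# The printed script now uses int32 loop variables and extents throughout.
print(narrowed.script())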
+# pylint: disable=invalid-name,,missing-function-docstring +import tvm +from tvm.tir.transform import DefaultGPUSchedule +from tvm.script import tir as T +import tvm.testing + + +def test_broadcast_to_symbolic(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + @tvm.script.ir_module + class Before: + @T.prim_func + def broadcast_to( + rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), + var_T_broadcast_to: T.handle, + ): + T.func_attr({"tir.noalias": True}) + x_0 = T.int64() + x_1 = T.int64() + T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1)) + # with T.block("root"): + for ax0, ax1 in T.grid(x_0, x_1): + with T.block("T_broadcast_to"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, T.int64(0)]) + T.writes(T_broadcast_to[v_ax0, v_ax1]) + T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def broadcast_to( + rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), + var_T_broadcast_to: T.handle, + ): + T.func_attr({"tir.noalias": True}) + x_0 = T.int64() + x_1 = T.int64() + T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1)) + # with T.block("root"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): + for ax0_ax1_fused_2 in T.thread_binding( + T.int64(1024), thread="threadIdx.x" + ): + for ax0_ax1_fused_0 in range( + (x_0 * x_1 + T.int64(262143)) // T.int64(262144) + ): + with T.block("T_broadcast_to"): + v_ax0 = T.axis.spatial( + x_0, + ( + (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) + * T.int64(1024) + + ax0_ax1_fused_2 + ) + // x_1, + ) + v_ax1 = T.axis.spatial( + x_1, + ( + (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) + * T.int64(1024) + + ax0_ax1_fused_2 + ) + % x_1, + ) + T.where( + (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) + * T.int64(1024) + + ax0_ax1_fused_2 + < x_0 * x_1 + ) + T.reads(rxplaceholder[v_ax0, T.int64(0)]) + T.writes(T_broadcast_to[v_ax0, v_ax1]) + T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)] + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + After = DefaultGPUSchedule()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_matmul(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + @tvm.script.ir_module + class Before: + @T.prim_func + def matmul( + A: T.Buffer((32, 32), "float16"), + B: T.Buffer((32, 32), "float16"), + C: T.Buffer((32, 32), "float16"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + for i, j, k in T.grid(32, 32, 32): + with T.block("C"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def matmul( + A: T.Buffer((32, 32), "float16"), + B: T.Buffer((32, 32), "float16"), + C: T.Buffer((32, 32), "float16"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + for k in range(32): + with T.block("C"): + v_i = T.axis.spatial( + 32, (i_j_fused_0 * 1024 + i_j_fused_1) 
// 32 + ) + v_j = T.axis.spatial( + 32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32 + ) + v_k = T.axis.reduce(32, k) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + After = DefaultGPUSchedule()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_add(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + @tvm.script.ir_module + class Before: + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), + rxplaceholder_1: T.Buffer( + (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32" + ), + T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i0_i1_i2_i3_fused_1 in T.thread_binding( + T.int64(72), thread="threadIdx.x" + ): + with T.block("T_add"): + ax0 = T.axis.spatial( + T.int64(4), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + // T.int64(18), + ) + ax1 = T.axis.spatial( + T.int64(3), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(18) + // T.int64(6), + ) + ax2 = T.axis.spatial( + T.int64(2), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(6) + // T.int64(3), + ) + ax3 = T.axis.spatial( + T.int64(3), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(3), + ) + T.reads( + rxplaceholder[T.int64(0), ax2, ax3], + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], + ) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = ( + rxplaceholder[T.int64(0), ax2, ax3] + + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + ) + + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + After = DefaultGPUSchedule()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_full(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + @tvm.script.ir_module + class Before: + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] 
= rxplaceholder[()] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def full( + rxplaceholder: T.Buffer((), "int32"), + T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): + with T.block("T_full"): + ax0 = T.axis.spatial( + T.int64(2), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), + ) + ax1 = T.axis.spatial( + T.int64(3), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3), + ) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + After = DefaultGPUSchedule()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_scheduled(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + + @tvm.script.ir_module + class Scheduled: + @T.prim_func + def full( + rxplaceholder: T.Buffer((), "int32"), + T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): + with T.block("T_full"): + ax0 = T.axis.spatial( + T.int64(2), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), + ) + ax1 = T.axis.spatial( + T.int64(3), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3), + ) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + # should do nothing + After = DefaultGPUSchedule()(Scheduled) + tvm.ir.assert_structural_equal(After, Scheduled) + + +def test_multiple(): + # pylint: disable=no-self-argument,missing-class-docstring,line-too-long + # fmt: off + @tvm.script.ir_module + class Before: + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), + 
rxplaceholder_1: T.Buffer( + (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32" + ), + T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i0_i1_i2_i3_fused_1 in T.thread_binding( + T.int64(72), thread="threadIdx.x" + ): + with T.block("T_add"): + ax0 = T.axis.spatial( + T.int64(4), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + // T.int64(18), + ) + ax1 = T.axis.spatial( + T.int64(3), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(18) + // T.int64(6), + ) + ax2 = T.axis.spatial( + T.int64(2), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(6) + // T.int64(3), + ) + ax3 = T.axis.spatial( + T.int64(3), + (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) + % T.int64(3), + ) + T.reads( + rxplaceholder[T.int64(0), ax2, ax3], + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], + ) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = ( + rxplaceholder[T.int64(0), ax2, ax3] + + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + ) + + @T.prim_func + def full( + rxplaceholder: T.Buffer((), "int32"), + T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): + with T.block("T_full"): + ax0 = T.axis.spatial( + T.int64(2), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), + ) + ax1 = T.axis.spatial( + T.int64(3), + (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3), + ) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + # pylint: enable=no-self-argument,missing-class-docstring,line-too-long + target = tvm.target.Target("nvidia/geforce-rtx-3070") + with target, tvm.transform.PassContext(opt_level=3): + After = DefaultGPUSchedule()(Before) + assert tvm.ir.structural_equal(After, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py b/tests/python/unittest/test_tvmscript_printer_ir.py index 6b3ac19a5ef8..d82e682c2949 100644 --- a/tests/python/unittest/test_tvmscript_printer_ir.py +++ b/tests/python/unittest/test_tvmscript_printer_ir.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from tvm import IRModule + +import pytest + +from tvm import IRModule, TVMError from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import ir as I from tvm.script.ir_builder import tir as T @@ -48,5 +51,17 @@ def foo(): ) +def test_failed_invalid_prefix(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module(): + with T.prim_func(): + T.func_name("foo") + mod = ib.get() + + with pytest.raises(TVMError): + mod.script(ir_prefix="2I") + + if __name__ == "__main__": test_ir_module() + test_failed_invalid_prefix() diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 83ea86ecccb8..9ca83ece5cd5 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -31,8 +31,8 @@ function shard1 { echo "Convert scripts to Python..." 
tests/scripts/task_convert_scripts_to_python.sh - echo "Check Jenkinsfile generation" - python3 ci/jenkins/generate.py --check + # echo "Check Jenkinsfile generation" + # python3 ci/jenkins/generate.py --check echo "Checking file types..." python3 tests/lint/check_file_type.py diff --git a/tests/scripts/unity/README b/tests/scripts/unity/README new file mode 100644 index 000000000000..42f8c3e040ea --- /dev/null +++ b/tests/scripts/unity/README @@ -0,0 +1,2 @@ +This folder contains CI task scripts that are specialized +to unity branch, please do not send to other places. diff --git a/tests/scripts/unity/task_extra_lint.sh b/tests/scripts/unity/task_extra_lint.sh new file mode 100755 index 000000000000..989f4df7389e --- /dev/null +++ b/tests/scripts/unity/task_extra_lint.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source tests/scripts/setup-pytest-env.sh + +# place extra lint here. diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh new file mode 100755 index 000000000000..8869c318fab7 --- /dev/null +++ b/tests/scripts/unity/task_python_relax.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source tests/scripts/setup-pytest-env.sh +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python +export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" + +# to avoid CI CPU thread throttling. 
+export TVM_BIND_THREADS=0 +export TVM_NUM_THREADS=2 + +make cython3 + +# Run Relax tests +TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax + +# Run Relax examples +# python3 ./apps/relax_examples/mlp.py +# python3 ./apps/relax_examples/nn_module.py +# python3 ./apps/relax_examples/resnet.py diff --git a/tests/scripts/unity/task_python_relax_gpuonly.sh b/tests/scripts/unity/task_python_relax_gpuonly.sh new file mode 100755 index 000000000000..acbcce44f279 --- /dev/null +++ b/tests/scripts/unity/task_python_relax_gpuonly.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +export TVM_TEST_TARGETS="llvm;cuda" +export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" +export TVM_RELAY_TEST_TARGETS="cuda" +export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu +export TVM_INTEGRATION_GPU_ONLY=1 + +./tests/scripts/unity/task_python_relax.sh diff --git a/web/.gitignore b/web/.gitignore index 1f7cc0916a5f..69bf96a8a726 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -4,3 +4,4 @@ out node_modules build debug +.ndarray_cache diff --git a/web/apps/browser/rpc_plugin.html b/web/apps/browser/rpc_plugin.html new file mode 100644 index 000000000000..87df60d42b60 --- /dev/null +++ b/web/apps/browser/rpc_plugin.html @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html index 6d353e29b08d..07e6fe87fc95 100644 --- a/web/apps/browser/rpc_server.html +++ b/web/apps/browser/rpc_server.html @@ -15,13 +15,19 @@ + - + TVM RPC Test Page + + + - + + +

TVM WebSocket RPC Server

To use this page
    @@ -59,21 +105,35 @@

    TVM WebSocket RPC Server

Options

    [rpc_server.html form markup stripped during extraction; in this hunk the "Proxy URL" and "RPC Server Key" label lines are rewritten and new "NDArrayCache" and "CacheDevice" fields are added to the options form.]
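The wasm_runtime.cc hunk below registers an in-memory NDArray cache ("tvmjs.ndarray_cache.get/update/remove/clear") plus "tvmjs.param_module_from_cache", which gathers entries named "<prefix>_0" ... "<prefix>_{n-1}" into a module exposing get_params; the NDArrayCache and CacheDevice options above are used to populate this cache from the browser-side RPC server. A hypothetical Python sketch of the calling convention — these globals are only compiled into the web runtime here, so the sketch assumes a build in which they are linked and exists purely to illustrate the protocol:

import numpy as np
import tvm

update = tvm.get_global_func("tvmjs.ndarray_cache.update")
get = tvm.get_global_func("tvmjs.ndarray_cache.get")

weight = tvm.nd.array(np.zeros((4, 4), dtype="float32"))
update("mlp_0", weight, False)          # name, array, override
assert get("mlp_0") is not None         # cache hit
assert get("missing") is None           # NullOpt maps to None

# Collect "mlp_0" into a param module whose get_params() returns [weight].
param_mod = tvm.get_global_func("tvmjs.param_module_from_cache")("mlp", 1)
params = param_mod["get_params"]()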
- diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 00d2a8c579f1..8f16365eee39 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -26,6 +26,7 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 + #define DMLC_USE_LOGGING_LIBRARY #include @@ -51,6 +52,12 @@ #include "src/runtime/rpc/rpc_session.cc" #include "src/runtime/system_library.cc" #include "src/runtime/workspace_pool.cc" +// relax setup +#include "src/runtime/relax_vm/builtin.cc" +#include "src/runtime/relax_vm/bytecode.cc" +#include "src/runtime/relax_vm/executable.cc" +#include "src/runtime/relax_vm/memory_manager.cc" +#include "src/runtime/relax_vm/vm.cc" // --- Implementations of backend and wasm runtime API. --- @@ -111,5 +118,107 @@ TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet // and get another value. *ret = (obj.use_count() - 1); }); + +/*! + * A NDArray cache to store pre-loaded arrays in the system. + */ +class NDArrayCache { + public: + static NDArrayCache* Global() { + static NDArrayCache* inst = new NDArrayCache(); + return inst; + } + + static void Update(String name, NDArray arr, bool override) { + NDArrayCache* pool = Global(); + if (!override) { + ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache"; + } + pool->pool_.Set(name, arr); + } + + static Optional Get(String name) { + NDArrayCache* pool = Global(); + auto it = pool->pool_.find(name); + if (it != pool->pool_.end()) { + return (*it).second; + } else { + return NullOpt; + } + } + + static void Remove(String name) { + NDArrayCache* pool = Global(); + pool->pool_.erase(name); + } + + static void Clear() { Global()->pool_.clear(); } + + private: + Map pool_; +}; + +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.get").set_body_typed(NDArrayCache::Get); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.update").set_body_typed(NDArrayCache::Update); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); + +void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format) { + if (format == "f32-to-bf16") { + std::vector buffer(bytes.length() / 2); + std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2); + // decode bf16 to f32 + const uint16_t* bf16 = reinterpret_cast(buffer.data()); + uint32_t* data = static_cast(cpu_arr->data); + ICHECK(cpu_arr.IsContiguous()); + size_t size = 1; + for (int i = 0; i < cpu_arr->ndim; ++i) { + size *= cpu_arr->shape[i]; + } + ICHECK_EQ(size, bytes.length() / 2); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(bf16[i]) << 16; + } + } else { + cpu_arr.CopyFromBytes(bytes.data(), bytes.length()); + } +} + +TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); + +class ParamModuleNode : public runtime::ModuleNode { + public: + const char* type_key() const final { return "param_module"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_params") { + auto params = params_; + return PackedFunc([params](TVMArgs args, TVMRetValue* rv) { *rv = params; }); + } else { + return PackedFunc(); + } + } + + static Module Create(std::string prefix, int num_params) { + Array params; + for (int i = 0; i < num_params; ++i) { + std::string name = prefix + "_" + std::to_string(i); + auto opt = NDArrayCache::Get(name); + if (opt) { + params.push_back(opt.value()); + } else { + 
LOG(FATAL) << "Cannot find " << name << " in cache"; + } + } + auto n = make_object(); + n->params_ = params; + return Module(n); + } + + private: + Array params_; +}; + +TVM_REGISTER_GLOBAL("tvmjs.param_module_from_cache").set_body_typed(ParamModuleNode::Create); } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 17efcc8c70a7..6c27a8207edc 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -160,6 +160,34 @@ class WebGPUModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "webgpu"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + // special function + if (name == "webgpu.get_fmap") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(fmap_); + *rv = os.str(); + }); + } else if (name == "webgpu.get_shader") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + auto it = smap_.find(name); + ICHECK(it != smap_.end()) << "Cannot find code " << name; + *rv = it->second; + }); + } else if (name == "webgpu.update_prebuild") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + PackedFunc func = args[1]; + prebuild_[name] = func; + }); + } + // check prebuild cache + auto prebuild_it = prebuild_.find(name); + if (prebuild_it != prebuild_.end()) { + return prebuild_it->second; + } + auto it = smap_.find(name); if (it != smap_.end()) { FunctionInfo info = fmap_.at(name); @@ -173,6 +201,8 @@ class WebGPUModuleNode final : public runtime::ModuleNode { } } + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + void SaveToFile(const std::string& file_name, const std::string& format) final { LOG(FATAL) << "Not implemented"; } @@ -185,12 +215,14 @@ class WebGPUModuleNode final : public runtime::ModuleNode { } private: - // function information table. + // code table std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The source std::string source_; + // prebuild_ functions + std::unordered_map prebuild_; // Callback to get the GPU function. 
TypedPackedFunc create_shader_; }; diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index e37d1838d604..22c4b4617e06 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -19,10 +19,9 @@ import { SizeOf, ArgTypeCode } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; -import { detectGPUDevice } from "./webgpu"; +import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu"; import * as compact from "./compact"; import * as runtime from "./runtime"; -import { timeStamp } from "console"; import { Disposable } from "./types"; enum RPCServerState { @@ -82,6 +81,10 @@ export class RPCServer { state: RPCServerState = RPCServerState.InitHeader; logger: (msg: string) => void; getImports: () => Record; + private ndarrayCacheUrl: string; + private ndarrayCacheDevice: string; + private initProgressCallback?: runtime.InitProgressCallback; + private asyncOnServerLoad?: (inst: runtime.Instance) => Promise; private pendingSend: Promise = Promise.resolve(); private name: string; private inst?: runtime.Instance = undefined; @@ -98,14 +101,21 @@ export class RPCServer { url: string, key: string, getImports: () => Record, - logger: (msg: string) => void = console.log + logger: (msg: string) => void = console.log, + ndarrayCacheUrl: string = "", + ndarrayCacheDevice: string = "cpu", + initProgressCallback: runtime.InitProgressCallback | undefined = undefined, + asyncOnServerLoad: ((inst: runtime.Instance) => Promise) | undefined = undefined, ) { this.url = url; this.key = key; this.name = "WebSocketRPCServer[" + this.key + "]: "; this.getImports = getImports; this.logger = logger; - + this.ndarrayCacheUrl = ndarrayCacheUrl; + this.ndarrayCacheDevice = ndarrayCacheDevice; + this.initProgressCallback = initProgressCallback; + this.asyncOnServerLoad = asyncOnServerLoad; this.checkLittleEndian(); this.socket = compact.createWebSocket(url); this.socket.binaryType = "arraybuffer"; @@ -127,12 +137,16 @@ export class RPCServer { this.globalObjects.forEach(obj => { obj.dispose(); }); + this.log(this.inst.runtimeStatsText()); this.inst.dispose(); } if (this.state == RPCServerState.ReceivePacketHeader) { this.log("Closing the server in clean state"); this.log("Automatic reconnecting.."); - new RPCServer(this.url, this.key, this.getImports, this.logger); + new RPCServer( + this.url, this.key, this.getImports, this.logger, + this.ndarrayCacheUrl, this.ndarrayCacheDevice, + this.initProgressCallback, this.asyncOnServerLoad); } else { this.log("Closing the server, final state=" + this.state); } @@ -257,12 +271,15 @@ export class RPCServer { this.getImports(), this.logger ); + try { - const gpuDevice: GPUDevice | undefined | null = await detectGPUDevice(); - if (gpuDevice !== undefined && gpuDevice !== null) { - const label = gpuDevice.label?.toString() || "WebGPU"; + const output: GPUDeviceDetectOutput | undefined = await detectGPUDevice(); + if (output !== undefined) { + const label = "WebGPU: "+ output.adapterInfo.description; this.log("Initialize GPU device: " + label); - inst.initWebGPU(gpuDevice); + inst.initWebGPU(output.device); + } else { + this.log("Cannot find WebGPU device in the env"); } } catch (err) { this.log("Cannnot initialize WebGPU, " + err.toString()); @@ -270,10 +287,25 @@ export class RPCServer { this.inst = inst; // begin scope to allow handling of objects - // the object should stay alive during all sessions. 
this.inst.beginScope(); - const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + if (this.initProgressCallback !== undefined) { + this.inst.registerInitProgressCallback(this.initProgressCallback); + } + if (this.ndarrayCacheUrl.length != 0) { + if (this.ndarrayCacheDevice == "cpu") { + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.cpu()); + } else { + assert(this.ndarrayCacheDevice == "webgpu"); + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.webgpu()); + } + } + + assert(this.inst !== undefined); + if (this.asyncOnServerLoad !== undefined) { + await this.asyncOnServerLoad(this.inst); + } + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); const messageHandler = fcreate( (cbytes: Uint8Array): runtime.Scalar => { assert(this.inst !== undefined); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index a24459ca29a0..f3a6029bbe8e 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -25,11 +25,10 @@ import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; import { Environment } from "./environment"; -import { WebGPUContext } from "./webgpu"; +import { FunctionInfo, WebGPUContext } from "./webgpu"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; -import { tsImportEqualsDeclaration } from "@babel/types"; /** * Type for PackedFunc inthe TVMRuntime. @@ -144,6 +143,12 @@ class RuntimeContext implements Disposable { arrayGetSize : PackedFunc; arrayMake : PackedFunc; getSysLib: PackedFunc; + arrayCacheGet: PackedFunc; + arrayCacheUpdate: PackedFunc; + arrayCacheRemove: PackedFunc; + arrayCacheClear: PackedFunc; + arrayDecodeStorage: PackedFunc; + paramModuleFromCache: PackedFunc; private autoDisposeScope: Array> = []; @@ -152,12 +157,27 @@ class RuntimeContext implements Disposable { this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); this.arrayMake = getGlobalFunc("runtime.Array"); this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.arrayCacheGet = getGlobalFunc("tvmjs.ndarray_cache.get"); + this.arrayCacheRemove = getGlobalFunc("tvmjs.ndarray_cache.remove"); + this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update"); + this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear"); + this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); + this.paramModuleFromCache = getGlobalFunc("tvmjs.param_module_from_cache"); + } dispose(): void { + // call array cache clear to clear all cached items + this.arrayCacheClear(); this.arrayGetItem.dispose(); this.arrayGetSize.dispose(); this.arrayMake.dispose(); + this.arrayCacheGet.dispose(); + this.arrayCacheRemove.dispose(); + this.arrayCacheUpdate.dispose(); + this.arrayCacheClear.dispose(); + this.arrayDecodeStorage.dispose(); + this.paramModuleFromCache.dispose(); } beginScope() : void { @@ -393,7 +413,7 @@ export class NDArray implements Disposable { /** Device of the array. */ device: DLDevice; /** Whether it is a temporary view that can become invalid after the call. */ - private isView: boolean; + isView: boolean; private byteOffset: number; private dltensor: Pointer; private dataPtr: Pointer; @@ -462,6 +482,18 @@ export class NDArray implements Disposable { return this.handle; } + /** + * Get dataPtr of NDarray + * + * @returns The handle. 
+ */ + getDataPtr(): Pointer { + if (this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.dataPtr; + } + dispose(): void { if (this.handle != 0 && !this.isView) { this.lib.checkCall( @@ -522,6 +554,12 @@ export class NDArray implements Disposable { * @returns this */ copyFromRawBytes(data: Uint8Array): this { + // short cut for gpu copy + if (this.device.deviceType == DeviceStrToEnum.webgpu) { + this.lib.webGPUContext?.copyRawBytesToBuffer(data, this.getDataPtr(), 0, data.length); + return this; + } + // CPU copy const size = this.shape.reduce((a, b) => { return a * b; }, 1); @@ -552,7 +590,7 @@ export class NDArray implements Disposable { */ toRawBytes(): Uint8Array { if (this.device.deviceType != DeviceStrToEnum.cpu) { - throw new Error("Can only synchronize copy for GPU array, use copyfrom instead."); + throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); } const size = this.shape.reduce((a, b) => { return a * b; @@ -648,9 +686,10 @@ export class Module implements Disposable { /** * Get a function in the module. * @param name The name of the function. + * @param queryImports Whether to also query imports * @returns The result function. */ - getFunction(name: string): PackedFunc { + getFunction(name: string, queryImports: boolean = true): PackedFunc { if (this.handle == 0) { throw Error("Module has already been disposed"); } @@ -667,7 +706,7 @@ export class Module implements Disposable { (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( this.getHandle(), stack.ptrFromOffset(nameOffset), - 1, + queryImports? 1 : 0, outPtr ) ); @@ -806,11 +845,82 @@ export class TVMArray extends TVMObject { } } +export const enum VMAllocatorKind { + NAIVE_ALLOCATOR = 1, + POOLED_ALLOCATOR = 2, +} + +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +export class VirtualMachine implements Disposable { + private mod: Module; + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + constructor(mod: Module, device: DLDevice) { + this.mod = mod; + this.mod.getFunction("vm_initialization")( + new Scalar(device.deviceType, "int"), + new Scalar(device.deviceId, "int"), + new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int") + ); + } + + dispose(): void { + this.mod.dispose(); + } + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc { + return this.mod.getFunction(name); + } + + /** + * Get the internal module. + */ + getInternalModule(): Module { + return this.mod; + } +} + /** Code used as the first argument of the async callback. */ const enum AyncCallbackCode { kReturn = 4, kException = 5, } +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + byteOffset: number; + nbytes: number; +} + +export interface NDArrayShardEntry { + dataPath: string; + format: "raw-shard"; + nbytes: number; + records: Array; +} + +export interface InitProgressReport { + progress: number; + timeElapsed: number; + text: string; +} + +export type InitProgressCallback = (report: InitProgressReport) => void; /** * TVM runtime instance. 
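The reworded `toRawBytes` error points at the copy-then-sync pattern; a sketch of it, assuming `tvm` is an `Instance` and `gpuArr` is a WebGPU-resident `NDArray`:

```typescript
// Read back a GPU array: stage it into a CPU NDArray, sync the device, then read bytes.
const cpuArr = tvm.empty(gpuArr.shape, gpuArr.dtype, tvm.cpu());
cpuArr.copyFrom(gpuArr);
await tvm.webgpu().sync();
const bytes = cpuArr.toRawBytes();
```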
@@ -832,10 +942,12 @@ const enum AyncCallbackCode { export class Instance implements Disposable { memory: Memory; exports: Record; + cacheMetadata: Record = {}; private lib: FFILibrary; private env: Environment; private objFactory: Map; private ctx: RuntimeContext; + private initProgressCallback: Array = []; /** * Internal function(registered by the runtime) @@ -872,7 +984,6 @@ export class Instance implements Disposable { env = new Environment(importObject); wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); } - env.start(wasmInstance); this.env = env; this.lib = new FFILibrary(wasmInstance, env.imports); @@ -898,33 +1009,45 @@ export class Instance implements Disposable { * @number The number of times to compute the average. * @repeat The number of times to repeat the run. */ - async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): Promise { - // Skip first run as it can involve GPU warmup and module loading time. - const perf = compact.getPerformance(); - const results = []; + async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=1): Promise { + // Skip first run as it can involve GPU warmup and module loading time. + const perf = compact.getPerformance(); + const results = []; - // run with new scope - this.withNewScope(run); - await dev.sync(); + // run with new scope + this.withNewScope(run); + await dev.sync(); - for (let k = 0; k < repeat; ++k) { - const tstart = perf.now(); - for (let i = 0; i < number; ++i) { - this.withNewScope(run); - } - await dev.sync(); - const tend = perf.now(); - results.push((tend - tstart) / number); + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.withNewScope(run); } - return results; + await dev.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); } + return results; + } dispose(): void { + // dispose canvas resource + this.lib.webGPUContext?.disposeCanvas(); // order matters // ctx release goes back into lib. this.ctx.dispose(); this.lib.dispose(); } + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string { + if (this.lib.webGPUContext !== undefined) { + return this.lib.webGPUContext.runtimeStatsText(); + } else { + return ""; + } + } /** * Begin a new scope for tracking object disposal. @@ -937,7 +1060,7 @@ export class Instance implements Disposable { * End a scope and release all created TVM objects * under the current scope. * - * Exception: one can call retainToParentScope to move + * Exception: one can call {@link moveToParentScope} to move * a value to parent scope. */ endScope(): void { @@ -951,7 +1074,7 @@ export class Instance implements Disposable { * @returns The result value. * * @note For action to return a valid value, - * we will need to call {@link retainToParentScope} + * we will need to call {@link moveToParentScope} * for the objects that are created in the scope. */ withNewScope(action: ()=>T): T { @@ -1131,9 +1254,9 @@ export class Instance implements Disposable { * @param func Input function. * @returns The converted function. 
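Since the default `repeat` drops from 4 to 1, callers that want several repeat measurements now pass it explicitly. A sketch, assuming `fadd`, `a`, `b`, and `c` already exist:

```typescript
// Time 4 repeats of 10 calls each; each entry is the average ms per call in that repeat.
const costs = await tvm.benchmark(() => { fadd(a, b, c); }, tvm.webgpu(), 10, 4);
console.log("avg ms/call:", costs.reduce((x, y) => x + y, 0) / costs.length);
```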
*/ - toPackedFunc(func: Function): PackedFunc { - return this.toPackedFuncInternal(func, true); - } + toPackedFunc(func: Function): PackedFunc { + return this.toPackedFuncInternal(func, true); + } private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { if (this.isPackedFunc(func)) return func as PackedFunc; @@ -1142,6 +1265,202 @@ export class Instance implements Disposable { return ret; } + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + createVirtualMachine(dev: DLDevice): VirtualMachine { + const mod = this.ctx.detachFromCurrentScope( + this.systemLib().getFunction("vm_load_executable")() + ); + return this.ctx.attachToCurrentScope( + new VirtualMachine(mod, dev) + ); + } + + //----------------------------------------------- + // Native NDArray Cache Support + //----------------------------------------------- + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + registerInitProgressCallback(cb: InitProgressCallback) { + this.initProgressCallback.push(cb); + } + + /** + * Get parameters in the form of prefix_i + * + * @param prefix The parameter prefix. + * @param numParams Number of parameters. + * @returns + */ + getParamsFromCache(prefix: string, numParams: number) : TVMObject { + return (this.ctx.paramModuleFromCache( + prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")(); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheGet(name: string) : NDArray | undefined { + return this.ctx.arrayCacheGet(name); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheRemove(name: string) : NDArray | undefined { + return this.ctx.arrayCacheRemove(name); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheUpdate(name: string, arr: NDArray, override: boolean = false) { + this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheClear() { + this.ctx.arrayCacheClear(); + } + + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + * @returns The meta data + */ + async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) : Promise { + const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + var list; + try { + list = await (await fetch(jsonUrl)).json(); + } catch(err) { + this.env.logger("Cannot fetch " + jsonUrl); + } + await this.fetchNDArrayCacheInternal( + ndarrayCacheUrl, + list["records"] as Array, device); + this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; + } + + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. 
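Taken together, these helpers cover the common load path; a rough sketch, where the parameter prefix and count are placeholders:

```typescript
// Create a relax VM on WebGPU, then pull cached parameters registered as param_0..param_1.
const vm = tvm.createVirtualMachine(tvm.webgpu());
const main = vm.getFunction("main");
const params = tvm.getParamsFromCache("param_", 2);
```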
+ */ + private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list: Array, device: DLDevice) { + const perf = compact.getPerformance(); + let tstart = perf.now(); + + let totalBytes = 0; + for (let i = 0; i < list.length; ++i) { + totalBytes += list[i].nbytes; + }; + let fetchedBytes = 0; + let timeElapsed = 0; + + const reportCallback = (iter: number)=> { + // report + for (let j = 0; j < this.initProgressCallback.length; ++j) { + let text = "Fetching param cache[" + iter + "/" + list.length+ "]: "; + text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. " + text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " + text += timeElapsed + " secs elapsed."; + text += " It can take a while when we first visit this page to populate the cache." + text += " Later refreshes will become faster."; + this.initProgressCallback[j]({ + progress: fetchedBytes / totalBytes, + timeElapsed: timeElapsed, + text: text + }); + } + }; + + for (let j = 0; j < this.initProgressCallback.length; ++j) { + this.initProgressCallback[j]({ + progress: fetchedBytes / totalBytes, + timeElapsed: 0, + text: "Start to fetch params", + }); + } + const cache = await caches.open("tvmjs"); + + for (let i = 0; i < list.length; ++i) { + reportCallback(i); + fetchedBytes += list[i].nbytes; + const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href; + const request = new Request(dataUrl); + let buffer; + try { + // use native cache + let result = await cache.match(request); + if (result === undefined) { + await cache.add(request); + result = await cache.match(request); + } + if (result == undefined) { + this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will be slow"); + result = await fetch(request); + } + buffer = await result.arrayBuffer(); + } catch (err) { + this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); + throw err; + } + const shardRecords = list[i].records; + for (let j = 0; j < shardRecords.length; ++j) { + const rec = shardRecords[j]; + const cpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, this.cpu()) + ) + }); + const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); + // first sync copy to cpu. + this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format); + // then async stream into GPU if needed + if (device.deviceType == DeviceStrToEnum.cpu) { + this.ndarrayCacheUpdate(rec.name, cpu_arr, false); + cpu_arr.dispose(); + } else { + // allocate a gpu arr and async copy to it. + const gpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, device) + ) + }); + gpu_arr.copyFrom(cpu_arr); + await device.sync(); + this.ndarrayCacheUpdate(rec.name, gpu_arr, false); + cpu_arr.dispose(); + gpu_arr.dispose(); + } + } + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + } + reportCallback(list.length); + } + /** * Convert dtype to {@link DLDataType} * @@ -1269,6 +1588,72 @@ export class Instance implements Disposable { return ret; } + /** + * Create am uniform {@link NDArray} with given shape. + * + * @param shape The shape of the array. + * @param low The low value. + * @param high The high value. + * @param dev The device of the ndarray. + * @returns The created ndarray. 
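For orientation, an illustrative record matching these interfaces; every value below is made up rather than taken from a real `ndarray-cache.json`:

```typescript
const exampleShard: NDArrayShardEntry = {
  dataPath: "params_shard_0.bin",
  format: "raw-shard",
  nbytes: 4194304,
  records: [
    {
      name: "w0",
      shape: [1024, 1024],
      dtype: "float32",
      format: "raw",          // or "f32-to-bf16" for bf16-packed float32 weights
      byteOffset: 0,
      nbytes: 4194304,        // 1024 * 1024 * 4 bytes
    },
  ],
};
```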
+ */ + uniform( + shape: Array, + low: number, + high: number, + dev: DLDevice + ): NDArray { + const ret = this.empty(shape, "float32", dev); + const size = shape.reduce((a, b) => { + return a * b; + }, 1); + const scale = high - low; + const input = new Float32Array(size); + for (let i = 0; i < input.length; ++i) { + input[i] = low + Math.random() * scale; + } + return ret.copyFrom(input); + } + + /** + * Bind canvas to the current WebGPU context + * @param canvas The canvas. + */ + bindCanvas(canvas: HTMLCanvasElement) { + this.lib.webGPUContext?.bindCanvas(canvas); + } + + /** + * Show image in canvas. + * + * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + */ + showImage(dataRGBA: NDArray) { + if (dataRGBA.shape.length != 2) { + throw Error("Require a height x width uint32 NDArray in RGBA" + + "get shape=" + dataRGBA.shape.toString() + " instead." + ); + } + if (dataRGBA.device.deviceType != DeviceStrToEnum.webgpu) { + throw new Error("Can only run showImage on WebGPU array, " + + "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead."); + } + if (dataRGBA.dtype != "uint32") { + throw Error("Require a height x width uint32 NDArray in RGBA, " + + "get " + dataRGBA.dtype + " instead."); + } + this.lib.webGPUContext?.drawImageFromBuffer( + dataRGBA.getDataPtr(), dataRGBA.shape[0], dataRGBA.shape[1] + ); + } + + /** + * Clear canvas + */ + clearCanvas() { + this.lib.webGPUContext?.clearCanvas(); + } + /** * Create an tuple {@link TVMArray} input array. * @@ -1356,6 +1741,71 @@ export class Instance implements Disposable { this.registerFunc("__async." + name, asyncVariant, override); } + /** + * Asynchrously load webgpu pipelines when possible. + * @param mod The input module. + */ + async asyncLoadWebGPUPiplines(mod: Module): Promise { + if (this.lib.webGPUContext == undefined) throw Error("WebGPU not initialied"); + const webgpuContext = this.lib.webGPUContext; + + this.beginScope(); + const fmap_str = mod.getFunction("webgpu.get_fmap", true)() as string; + let fmap: Record = JSON.parse(fmap_str); + const totalFuncs = fmap.length; + const fGetShader = this.detachFromCurrentScope( + mod.getFunction("webgpu.get_shader") + ); + const fUpdatePrebuild = this.detachFromCurrentScope( + mod.getFunction("webgpu.update_prebuild") + ); + this.endScope(); + + const perf = compact.getPerformance(); + const tstart = perf.now(); + let tlastReport = tstart; + let finishCounter = 0; + const fmapEntries = Object.entries(fmap); + + let allEvents = Promise.resolve(); + + for (const [key, finfo] of fmapEntries) { + const code = fGetShader(key); + assert(key == finfo.name); + const event = webgpuContext.createShaderAsync(finfo, code).then((func: Function) => { + this.beginScope(); + fUpdatePrebuild(key, func); + this.endScope(); + + }).then(() => { + finishCounter += 1; + const tend = perf.now(); + const timeReportGap = 1000; + // skip report if gap is smaller than 1000 + if ((tend - tlastReport) < 1000 && finishCounter != fmapEntries.length) { + return; + } + tlastReport = tend; + const timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + // report + for (let j = 0; j < this.initProgressCallback.length; ++j) { + const progress = finishCounter / fmapEntries.length; + let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length+ "]: "; + text += Math.floor(progress * 100).toString() + "% completed, " + text += timeElapsed + " secs elapsed."; + this.initProgressCallback[j]({ + progress: progress, + timeElapsed: timeElapsed, + text: text + 
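A sketch of the canvas path, assuming `tvm` is an `Instance` with WebGPU initialized; the canvas element id and image size are placeholders:

```typescript
const noise = tvm.uniform([512], 0.0, 1.0, tvm.cpu());  // host-side random init helper
const canvas = document.getElementById("canvas") as HTMLCanvasElement;
tvm.bindCanvas(canvas);
// showImage expects a height x width uint32 RGBA NDArray resident on WebGPU.
const rgba = tvm.empty([512, 512], "uint32", tvm.webgpu());
// ... produce the image into `rgba` on the GPU ...
tvm.showImage(rgba);
tvm.clearCanvas();
```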
}); + } + }); + allEvents = Promise.all([allEvents, event]).then(()=>{}); + } + await allEvents; + assert(finishCounter == fmapEntries.length); + } + /** * Initialize webgpu in the runtime. * @param device The given GPU device. @@ -1368,7 +1818,8 @@ export class Instance implements Disposable { return webGPUContext.getDeviceAPI(name); }); this.registerFunc("wasm.WebGPUCreateShader", (info: string, code: string) => { - return webGPUContext.createShader(info, code); + const finfo = JSON.parse(info) as FunctionInfo; + return webGPUContext.createShader(finfo, code); }); this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { await webGPUContext.sync(); @@ -1500,8 +1951,13 @@ export class Instance implements Disposable { const valueOffset = argsValue + i * SizeOf.TVMValue; const codeOffset = argsCode + i * SizeOf.I32; if (val instanceof NDArray) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + if (!val.isView) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + } else { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle); + } } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { stack.storeI64(valueOffset, val.value); diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index faf6fac990c8..ac39595c7662 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -20,23 +20,256 @@ import "@webgpu/types"; import { assert } from "./support"; import { Pointer } from "./ctypes"; import { Memory } from "./memory"; +import { Disposable } from "./types"; /** A pointer to points to the raw address space. */ export type GPUPointer = number; +export interface GPUDeviceDetectOutput { + adapter: GPUAdapter; + adapterInfo: GPUAdapterInfo; + device: GPUDevice; +} + /** * DetectGPU device in the environment. 
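A sketch of the same bring-up done manually (error handling trimmed; assumes the loaded module embeds WebGPU kernels and their `webgpu.get_fmap` metadata):

```typescript
const output = await detectGPUDevice();
if (output !== undefined) {
  console.log("WebGPU adapter: " + output.adapterInfo.description);
  tvm.initWebGPU(output.device);
  // Pre-build compute pipelines off the critical path.
  await tvm.asyncLoadWebGPUPiplines(tvm.systemLib());
} else {
  console.log("WebGPU is not available in this environment");
}
```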
*/ -export async function detectGPUDevice(): Promise { +export async function detectGPUDevice(): Promise { if (typeof navigator !== "undefined" && navigator.gpu !== undefined) { const adapter = await navigator.gpu.requestAdapter(); - return await adapter?.requestDevice(); + if (adapter == null) { + throw Error("Cannot find adapter that matches the request"); + } + const adapterInfo = await adapter.requestAdapterInfo(); + const device = await adapter.requestDevice({ + requiredLimits: { + maxBufferSize: 1 << 30, + maxStorageBufferBindingSize: 1 << 30, + maxComputeWorkgroupStorageSize: 32 << 10, + } + }); + return { + adapter: adapter, + adapterInfo: adapterInfo, + device: device + }; } else { return undefined; } } -interface FunctionInfo { +const canvasRenderWGSL =` +@group(0) @binding(0) var my_sampler : sampler; +@group(0) @binding(1) var my_texture : texture_2d; + +struct VertexOutput { + @builtin(position) position : vec4, + @location(0) uv : vec2, +} + +@vertex +fn vertex_main(@builtin(vertex_index) vidx : u32) -> VertexOutput { + const pos = array( + vec2( 1.0, 1.0), + vec2( 1.0, -1.0), + vec2(-1.0, -1.0), + vec2( 1.0, 1.0), + vec2(-1.0, -1.0), + vec2(-1.0, 1.0), + ); + + const uv = array( + vec2(1.0, 0.0), + vec2(1.0, 1.0), + vec2(0.0, 1.0), + vec2(1.0, 0.0), + vec2(0.0, 1.0), + vec2(0.0, 0.0), + ); + + var output : VertexOutput; + output.position = vec4(pos[vidx], 0.0, 1.0); + output.uv = uv[vidx]; + return output; +} + +@fragment +fn fragment_main(@location(0) uv : vec2) -> @location(0) vec4 { + return textureSample(my_texture, my_sampler, uv); +} + +@fragment +fn fragment_clear(@location(0) uv : vec2) -> @location(0) vec4 { + return vec4(1.0, 1.0, 1.0, 1.0); +} +` +class CanvaRenderManager implements Disposable { + private device: GPUDevice; + private canvasContext: GPUCanvasContext; + private stagingTexture: GPUTexture; + private renderSampler: GPUSampler; + private renderPipeline: GPURenderPipeline; + private clearPipeline: GPURenderPipeline; + private canvasTextureFormat: GPUTextureFormat; + + constructor(device: GPUDevice, canvas: HTMLCanvasElement) { + this.device = device; + const ctx = canvas.getContext("webgpu"); + if (ctx == null) { + throw Error("Cannot bind WebGPU context"); + } + this.canvasContext = ctx; + this.canvasTextureFormat = navigator.gpu.getPreferredCanvasFormat(); + this.canvasContext.configure({ + device: this.device, + format: this.canvasTextureFormat, + alphaMode: "opaque", + }); + + this.renderPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_main", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + + this.clearPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_clear", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + + this.renderSampler = device.createSampler({ + magFilter: "linear", + minFilter: "linear", + }); + // staging texture always be in RGBA + this.stagingTexture = device.createTexture({ + size: [canvas.height, canvas.width, 1], + format: "rgba8unorm", + 
usage: + GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + + clear() { + const commandEncoder = this.device.createCommandEncoder(); + const passEncoder = commandEncoder.beginRenderPass({ + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.clearPipeline); + const renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + } + + draw(buffer: GPUBuffer, height: number, width: number) { + // resize the staging texture + if (height != this.stagingTexture.height || width != this.stagingTexture.width) { + this.stagingTexture.destroy(); + this.stagingTexture = this.device.createTexture({ + size: [height, width, 1], + format: "rgba8unorm", + usage: + GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + + const commandEncoder = this.device.createCommandEncoder(); + commandEncoder.copyBufferToTexture({ + buffer: buffer, + offset: 0, + bytesPerRow: this.stagingTexture.width * 4 + }, { + texture: this.stagingTexture + },{ + width: this.stagingTexture.width, + height: this.stagingTexture.height + }); + + const passEncoder = commandEncoder.beginRenderPass({ + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.renderPipeline); + const renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + } + + dispose() : void { + this.stagingTexture.destroy(); + } +} + +/** + * Function info from the API + */ +export interface FunctionInfo { name: string; arg_types: Array; launch_param_tags: Array; @@ -49,12 +282,24 @@ interface FunctionInfo { export class WebGPUContext { device: GPUDevice; memory: Memory; - - //private readBuffer:; + // internal data private bufferTable: Array = [undefined]; private bufferTableFreeId: Array = []; - private pendingRead: Promise = Promise.resolve(); - private numPendingReads = 0; + private canvasRenderManager?: CanvaRenderManager = undefined; + // flags for debugging + // stats of the runtime. 
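For context, an illustrative `FunctionInfo` record of the kind the runtime hands to `createShader`; the kernel name and tag values are made up, and the `paramWriteAccess` tag is consumed by the shader-creation path further below:

```typescript
const finfo: FunctionInfo = {
  name: "main_kernel0",                 // also used as the WGSL entry point
  arg_types: ["handle", "handle"],      // two storage-buffer arguments
  launch_param_tags: [
    "blockIdx.x", "threadIdx.x",        // launch extents, in call order
    "paramWriteAccess:[0,1]",           // arg 0 read-only, arg 1 writable
  ],
};
```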
+ // peak allocation + private peakAllocatedBytes: number = 0; + // current allocation + private currAllocatedBytes: number = 0; + // all allocation(ignoring free) + private allAllocatedBytes: number = 0; + // shader submit counter + private shaderSubmitCounter: number = 0; + // limite number of shaders to be submitted, useful for debugging, default to -1 + protected debugShaderSubmitLimit: number = -1; + // log and sync each step + protected debugLogFinish: boolean = false; constructor(memory: Memory, device: GPUDevice) { this.memory = memory; @@ -65,56 +310,117 @@ export class WebGPUContext { * Wait for all pending GPU tasks to complete */ async sync(): Promise { - if (this.numPendingReads != 0) { - await Promise.all([ - this.device.queue.onSubmittedWorkDone(), - this.pendingRead - ]) - } else { - await this.device.queue.onSubmittedWorkDone() + await this.device.queue.onSubmittedWorkDone(); + } + + /** + * Dispose the binded canvas. + */ + disposeCanvas() { + this.canvasRenderManager?.dispose(); + this.canvasRenderManager = undefined; + } + + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string { + let info = "peak-memory=" + Math.ceil(this.peakAllocatedBytes / (1 << 20)) + " MB"; + info += ", all-memory=" + Math.ceil(this.allAllocatedBytes / (1 << 20)) + " MB"; + info += ", shader-submissions=" + this.shaderSubmitCounter; + return info; + } + + /** + * Draw image from data in storage buffer. + * @param ptr The GPU ptr + * @param height The height of the image. + * @param width The width of the image. + */ + drawImageFromBuffer(ptr: GPUPointer, height: number, width: number) { + if (this.canvasRenderManager == undefined) { + throw Error("Do not have a canvas context, call bindCanvas first"); } + this.canvasRenderManager.draw(this.gpuBufferFromPtr(ptr), height, width); + } + + /** + * Copy raw bytes into buffer ptr. + * + * @param rawBytes The raw bytes + * @param toPtr The target gpu buffer ptr + * @param toOffset The beginning offset + * @param nbytes Number of bytes + */ + copyRawBytesToBuffer( + rawBytes: Uint8Array, + toPtr: GPUPointer, + toOffset: number, + nbytes: number + ): void { + // Perhaps it would be more useful to use a staging buffer? + this.device.queue.writeBuffer( + this.gpuBufferFromPtr(toPtr), + toOffset, + rawBytes, + 0, + nbytes + ); + } + /** + * Clear canvas + */ + clearCanvas() { + this.canvasRenderManager?.clear(); + } + + /** + * Bind a canvas element to the runtime. + * @param canvas The HTML canvas/ + */ + bindCanvas(canvas: HTMLCanvasElement) { + this.canvasRenderManager = new CanvaRenderManager(this.device, canvas); } /** * Create a PackedFunc that runs the given shader + * via createComputePipeline * - * @param info The function information in json. + * @param info The function information already parsed as a record. 
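A quick way to surface these counters from application code, assuming `tvm` is an `Instance`:

```typescript
// Illustrative output: "peak-memory=512 MB, all-memory=640 MB, shader-submissions=1024"
console.log(tvm.runtimeStatsText());
```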
* @param code The shader data(in WGSL) + * @returns The shader */ - createShader(info: string, code: string): Function { - const finfo = JSON.parse(info); - const layoutEntries: Array = []; - for (let i = 0; i < finfo.arg_types.length; ++i) { - const dtype = finfo.arg_types[i]; - if (dtype == "handle") { - layoutEntries.push({ - binding: i, - visibility: GPUShaderStage.COMPUTE, - buffer : { - type: "storage" - } - }); - } else { - throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); - } - } - const bindGroupLayout = this.device.createBindGroupLayout({ - entries: layoutEntries - }); + createShader(finfo: FunctionInfo, code: string) : Function { + return this.createShadeInternl(finfo, code, false) as Function; + } - const pipeline = this.device.createComputePipeline({ - layout: this.device.createPipelineLayout({ - bindGroupLayouts: [ bindGroupLayout ] - }), - compute: { - module: this.device.createShaderModule({ - code: code - }), - entryPoint: "main" - } - }); + /** + * Create a PackedFunc that runs the given shader asynchrously + * via createComputePipelineAsync + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + async createShaderAsync(finfo: FunctionInfo, code: string) : Promise { + return await (this.createShadeInternl(finfo, code, true) as Promise); + } + /** + * Internal impl of createShader for both async and sync mode. + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @param asyncMode Whether use async mode. + * @returns The shader function or promise of shader func. + */ + private createShadeInternl( + finfo: FunctionInfo, + code: string, + asyncMode: boolean + ): Function | Promise { const dispatchToDim: Array = []; + let paramWriteAccess: Array = []; for (let i = 0; i < finfo.launch_param_tags.length; ++i) { const tag: string = finfo.launch_param_tags[i]; @@ -126,42 +432,134 @@ export class WebGPUContext { const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); assert(target >= 0 && target < 3); dispatchToDim.push(target + 3); + } else if (tag.startsWith("paramWriteAccess:")) { + paramWriteAccess = JSON.parse(tag.substring(17)); } else { throw new Error("Cannot handle thread_axis " + tag); } } - const submitShader = (...args: Array): void => { - const commandEncoder = this.device.createCommandEncoder(); - const compute = commandEncoder.beginComputePass(); - compute.setPipeline(pipeline); - const bindGroupEntries: Array = []; - assert(args.length == layoutEntries.length + dispatchToDim.length); + assert(paramWriteAccess.length == finfo.arg_types.length); - for (let i = 0; i < layoutEntries.length; ++i) { - bindGroupEntries.push({ + const layoutEntries: Array = []; + for (let i = 0; i < finfo.arg_types.length; ++i) { + const dtype = finfo.arg_types[i]; + if (dtype == "handle") { + layoutEntries.push({ binding: i, - resource: { - buffer: this.gpuBufferFromPtr(args[i]) + visibility: GPUShaderStage.COMPUTE, + buffer : { + type: paramWriteAccess[i] ? 
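A usage sketch for the async variant, reusing the illustrative `finfo` record above and a hypothetical `wgslCode` string:

```typescript
const submit = await webgpuContext.createShaderAsync(finfo, wgslCode);
// The returned function takes one GPU buffer pointer per "handle" argument,
// followed by the launch extents in tag order, e.g. submit(ptrA, ptrB, gridX, threadX).
```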
"storage" : "read-only-storage" } }); + } else { + throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); } + } + const bindGroupLayout = this.device.createBindGroupLayout({ + entries: layoutEntries + }); + const pipelineLayout = this.device.createPipelineLayout({ + bindGroupLayouts: [ bindGroupLayout ] + }); - compute.setBindGroup(0, this.device.createBindGroup({ - layout: bindGroupLayout, - entries: bindGroupEntries - })); - const wl: Array = [1, 1, 1, 1, 1, 1]; - for (let i = 0; i < dispatchToDim.length; ++i) { - wl[dispatchToDim[i]] = args[layoutEntries.length + i]; - } - compute.dispatchWorkgroups(wl[0], wl[1], wl[2]) - compute.end() - const command = commandEncoder.finish(); - this.device.queue.submit([command]); + // Function to create the pipeline. + const createShaderFunc = (pipeline: GPUComputePipeline): Function => { + const submitShader = (...args: Array): void => { + if (this.debugShaderSubmitLimit != -1 && + this.shaderSubmitCounter >= this.debugShaderSubmitLimit) { + this.shaderSubmitCounter += 1; + return; + } + + const commandEncoder = this.device.createCommandEncoder(); + const compute = commandEncoder.beginComputePass(); + compute.setPipeline(pipeline); + const bindGroupEntries: Array = []; + assert(args.length == layoutEntries.length + dispatchToDim.length); + + for (let i = 0; i < layoutEntries.length; ++i) { + bindGroupEntries.push({ + binding: i, + resource: { + buffer: this.gpuBufferFromPtr(args[i]) + } + }); + } + + compute.setBindGroup(0, this.device.createBindGroup({ + layout: bindGroupLayout, + entries: bindGroupEntries + })); + const wl: Array = [1, 1, 1, 1, 1, 1]; + for (let i = 0; i < dispatchToDim.length; ++i) { + wl[dispatchToDim[i]] = args[layoutEntries.length + i]; + } + + // get around 65535 restriction of blockIdx.x + if (wl[2] != 1) { + throw Error("WebGPU: blockIdx.z is reserved for internal use"); + } + // spread thinsg out into blockIdx.z + if (wl[0] >= (1 << 16)) { + let wl_x = wl[0]; + let wl_z = wl[2]; + + while (wl_x >= (1 << 16)) { + if (wl_x % 2 != 0) { + throw Error("WebGPU: cannot factorize big gridDim.x=" + wl[0].toString()); + } + wl_x /= 2; + wl_z *= 2; + } + wl[0] = wl_x; + wl[2] = wl_z; + } + compute.dispatchWorkgroups(wl[0], wl[1], wl[2]) + compute.end() + const command = commandEncoder.finish(); + this.device.queue.submit([command]); + + if (this.debugLogFinish) { + const currCounter = this.shaderSubmitCounter; + this.device.queue.onSubmittedWorkDone().then(()=> { + console.log("["+ currCounter + "][Debug] finish shader" + finfo.name); + }); + } + this.shaderSubmitCounter += 1; + }; + return submitShader; }; - return submitShader; + const shaderModule = this.device.createShaderModule({ + code: code, + hints: { + main: { + layout: pipelineLayout + } + } + }); + + if (asyncMode) { + return this.device.createComputePipelineAsync({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }).then((pipeline: GPUComputePipeline) => { + return createShaderFunc(pipeline); + }); + } else { + const pipeline = this.device.createComputePipeline({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }); + return createShaderFunc(pipeline); + } } /** @@ -209,7 +607,6 @@ export class WebGPUContext { } else { throw new Error("Unknown DeviceAPI function " + name); } - } // DeviceAPI @@ -218,7 +615,13 @@ export class WebGPUContext { size: nbytes, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, }); - return 
this.attachToBufferTable(buffer); + this.currAllocatedBytes += nbytes; + this.allAllocatedBytes += nbytes; + if (this.currAllocatedBytes > this.peakAllocatedBytes) { + this.peakAllocatedBytes = this.currAllocatedBytes; + } + const ptr = this.attachToBufferTable(buffer); + return ptr; } private deviceFreeDataSpace(ptr: GPUPointer): void { @@ -227,6 +630,7 @@ export class WebGPUContext { this.bufferTable[idx] = undefined; assert(buffer !== undefined); this.bufferTableFreeId.push(idx); + this.currAllocatedBytes -= buffer.size; buffer.destroy(); } @@ -237,29 +641,14 @@ export class WebGPUContext { nbytes: number ): void { // Perhaps it would be more useful to use a staging buffer? - const gpuTemp = this.device.createBuffer({ - mappedAtCreation: true, - size: nbytes, - usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC - }); - - const cpuTemp = gpuTemp.getMappedRange(); - - const viewU8 = new Uint8Array(cpuTemp); - viewU8.set(this.memory.loadRawBytes(from, nbytes)); - gpuTemp.unmap(); - - const copyEncoder = this.device.createCommandEncoder(); - copyEncoder.copyBufferToBuffer( - gpuTemp, - 0, + const rawBytes = this.memory.loadRawBytes(from, nbytes); + this.device.queue.writeBuffer( this.gpuBufferFromPtr(to), toOffset, + rawBytes, + 0, nbytes ); - const copyCommands = copyEncoder.finish(); - this.device.queue.submit([copyCommands]); - gpuTemp.destroy(); } private deviceCopyFromGPU( @@ -285,24 +674,11 @@ export class WebGPUContext { const copyCommands = copyEncoder.finish(); this.device.queue.submit([copyCommands]); - this.numPendingReads += 1; - - const readEvent = gpuTemp.mapAsync(GPUMapMode.READ).then(() => { + gpuTemp.mapAsync(GPUMapMode.READ).then(() => { const data = gpuTemp.getMappedRange(); this.memory.storeRawBytes(to, new Uint8Array(data)); - this.numPendingReads -= 1; gpuTemp.destroy(); }); - - if (this.numPendingReads == 1) { - this.pendingRead = readEvent; - } else { - this.pendingRead = Promise.all([ - this.pendingRead, - readEvent, - // eslint-disable-next-line @typescript-eslint/no-empty-function - ]).then(() => {}); - } } private deviceCopyWithinGPU( diff --git a/web/tests/node/test_relax_vm.js b/web/tests/node/test_relax_vm.js new file mode 100644 index 000000000000..ceb47aa014ec --- /dev/null +++ b/web/tests/node/test_relax_vm.js @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +/* eslint-disable no-undef */ +// Load Emscripten Module, need to change path to root/lib +const path = require("path"); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_relax.wasm")); + +const tvm = new tvmjs.Instance( + new WebAssembly.Module(wasmSource), + new EmccWASI() +); + + +function randomArray(length, max) { + return Array.apply(null, Array(length)).map(function () { + return Math.random() * max; + }); +} + +test("add one", () => { + tvm.beginScope(); + // Load system library + const vm = tvm.createVirtualMachine(tvm.cpu()); + // grab pre-loaded function + const fadd = vm.getFunction("main"); + + assert(tvm.isPackedFunc(fadd)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n).copyFrom(randomArray(n, 1)); + + // call the function. + const C = fadd(A, B); + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array + const CC = C.toArray(); // retrieve values in js array + // verify + for (var i = 0; i < BB.length; ++i) { + assert(Math.abs(CC[i] - (AA[i] + BB[i])) < 1e-5); + } + tvm.endScope(); + // assert auto release scope behavior + assert(vm.mod.getHandle(false) == 0); + assert(fadd._tvmPackedCell.getHandle(false) == 0); +}); diff --git a/web/tests/python/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py index 5c1f7c68c421..a63e0655b45d 100644 --- a/web/tests/python/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -18,12 +18,32 @@ import tvm from tvm import te -from tvm.contrib import emcc +from tvm.contrib import tvmjs from tvm.relay.backend import Runtime +from tvm import relax +from tvm.script import relax as R import os -def prepare_test_libs(base_path): +def prepare_relax_lib(base_path): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")): + lv0 = R.add(x, y) + return lv0 + + target = tvm.target.Target("llvm -mtriple=wasm32-unknown-unknown-wasm") + + mod = pipeline(Mod) + ex = relax.build(mod, target) + wasm_path = os.path.join(base_path, "test_relax.wasm") + ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm) + + +def prepare_tir_lib(base_path): runtime = Runtime("cpp", {"system-lib": True}) target = "llvm -mtriple=wasm32-unknown-unknown-wasm" if not tvm.runtime.enabled(target): @@ -35,9 +55,11 @@ def prepare_test_libs(base_path): fadd = tvm.build(s, [A, B], target, runtime=runtime, name="add_one") wasm_path = os.path.join(base_path, "test_addone.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) if __name__ == "__main__": curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../dist/wasm")) + base_path = os.path.join(curr_path, "../../dist/wasm") + prepare_tir_lib(base_path) + prepare_relax_lib(base_path) diff --git a/web/tests/python/relax_rpc_test.py b/web/tests/python/relax_rpc_test.py new file mode 100644 index 000000000000..a347fe70b345 --- /dev/null +++ b/web/tests/python/relax_rpc_test.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test relax vm through rpc.""" + +import tvm +import numpy as np +from tvm import rpc, relax +from tvm.contrib import utils, tvmjs +from tvm.script import relax as R + +proxy_host = "127.0.0.1" +proxy_port = 9090 + + +def get_model(): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], "float32")): + lv0 = R.add(x, y) + return lv0 + + mod = pipeline(Mod) + sch = tvm.tir.Schedule(mod) + # manually transform loop + sch.work_on("add") + (i,) = sch.get_loops(block=sch.get_block("T_add")) + i0, i1 = sch.split(i, [None, 128]) + sch.bind(i0, "blockIdx.x") + sch.bind(i1, "threadIdx.x") + return sch.mod + + +def test_rpc(): + if not tvm.runtime.enabled("rpc"): + return + n = 1024 + dtype = "float32" + temp = utils.tempdir() + wasm_path = temp.relpath("relax.wasm") + target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") + + mod = get_model() + ex = relax.build(mod, target) + ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm) + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect( + proxy_host, + proxy_port, + key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary], + ) + + def check(remote): + dev = remote.webgpu(0) + # invoke the function + vm = relax.VirtualMachine(remote.system_lib(), device=dev) + adata = np.random.uniform(size=n).astype(dtype) + bdata = np.random.uniform(size=n).astype(dtype) + a = tvm.nd.array(adata, dev) + b = tvm.nd.array(bdata, dev) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + c = vm.get_outputs("main") + np.testing.assert_equal(c.numpy(), a.numpy() + b.numpy()) + + check(remote) + + +test_rpc() diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index 6e34a8a2b36c..986393e9d41d 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -23,7 +23,7 @@ import tvm from tvm import te from tvm import rpc -from tvm.contrib import utils, emcc +from tvm.contrib import utils, tvmjs from tvm.relay.backend import Runtime import numpy as np @@ -52,7 +52,7 @@ def test_rpc(): temp = utils.tempdir() wasm_path = temp.relpath("addone_gpu.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) wasm_binary = open(wasm_path, "rb").read() remote = rpc.connect( diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index 7de5ee956ec8..19d5dc57480c 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -23,7 +23,7 @@ import tvm from tvm import te from tvm import rpc -from tvm.contrib import utils, emcc +from tvm.contrib import utils, tvmjs from tvm.relay.backend import Runtime import numpy as np @@ -48,7 +48,7 @@ def test_rpc(): 
temp = utils.tempdir() wasm_path = temp.relpath("addone.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) wasm_binary = open(wasm_path, "rb").read()