diff --git a/CMakeLists.txt b/CMakeLists.txt index a4cfa0b59ad8..b499edd4560f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,6 +112,10 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") + + # MSVC already errors on undefined symbols, no additional flag needed. + set(TVM_NO_UNDEFINED_SYMBOLS "") + if(USE_MSVC_MT) foreach(flag_var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE @@ -165,6 +169,16 @@ else(MSVC) set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}") endif() + # ld option to warn if symbols are undefined (e.g. libtvm_runtime.so + # using symbols only present in libtvm.so). Not needed for MSVC, + # since this is already the default there. + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(TVM_NO_UNDEFINED_SYMBOLS "-Wl,-undefined,error") + else() + set(TVM_NO_UNDEFINED_SYMBOLS "-Wl,--no-undefined") + endif() + message(STATUS "Forbidding undefined symbols in shared library, using ${TVM_NO_UNDEFINED_SYMBOLS} on platform ${CMAKE_SYSTEM_NAME}") + # Detect if we're compiling for Hexagon. set(TEST_FOR_HEXAGON_CXX "#ifndef __hexagon__" @@ -414,6 +428,7 @@ add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) add_library(tvm SHARED $ $) +set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") if(BUILD_STATIC_RUNTIME) add_library(tvm_runtime STATIC $) @@ -425,6 +440,7 @@ if(BUILD_STATIC_RUNTIME) COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE}) else() add_library(tvm_runtime SHARED $) + set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") endif() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") target_compile_definitions(tvm_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6ebb0575d72b..ad4e2dee4331 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,32 +42,33 @@ We do encourage everyone to work anything they are interested in. - [Aditya Atluri](https://github.com/adityaatluri): @adityaatluri - rocm - [Matthew Barrett](https://github.com/mbaret): @mbaret - byoc, arm - [Matthew Brookhart](https://github.com/mbrookhart): @mbrookhart - relay, frontends -- [Tianqi Chen](https://github.com/tqchen) (PMC): @tqchen - topi, compiler, relay, docs - [Liangfu Chen](https://github.com/liangfu): @liangfu - vta, chisel, intel FPGA, c runtime +- [Tianqi Chen](https://github.com/tqchen) (PMC): @tqchen - topi, compiler, relay, docs - [Wei Chen](https://github.com/wweic): @wweic - runtime, relay, vm - [Zhi Chen](https://github.com/zhiics) (PMC): @zhiics - relay, quantization, pass manager -- [Chenfan](https://github.com/jcf94): @jcf94 - auto_scheduler - [Josh Fromm](https://github.com/jwfromm): @jwfromm - frontends, quantization, topi - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends - [Nick Hynes](https://github.com/nhynes): @nhynes: - sgx, rust - [Animesh Jain](https://github.com/anijain2305): @anijain2305 - quantization, relay +- [Chenfan Jia](https://github.com/jcf94): @jcf94 - auto_scheduler - [Ziheng Jiang](https://github.com/ZihengJiang) (PMC): @ZihengJiang - relay, compiler - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame - relay - [Wuwei Lin](https://github.com/vinx13): @vinx13 - relay, topi - [Yizhi Liu](https://github.com/yzhliu) (PMC): @yzhliu - jvm, topi, relay -- [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - relay - [Hao Lu](https://github.com/hlu1): @hlu1 - nnpack, frontends +- [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - relay - [Masahiro Masuda](https://github.com/masahi) (PMC): @masahi - topi, relay - [Thierry Moreau](https://github.com/tmoreau89) (PMC): @tmoreau89 - vta - [Kazutaka Morita](https://github.com/kazum): @kazum - frontends, opencl +- [Trevor Morris](https://github.com/trevor-m): @trevor-m - byoc, compiler - [Leandro Nunes](https://github.com/leandron): @leandron - tvmc - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - hexagon, llvm - [Andrew Reusch](https://github.com/areusch): @areusch - runtime, µTVM - [Jared Roesch](https://github.com/jroesch) (PMC): @jroesch - relay - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - frontends -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Junru Shao](https://github.com/junrushao1994) @junrushao1994 - relay, compiler - [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi +- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - vta, chisel @@ -85,28 +86,25 @@ We do encourage everyone to work anything they are interested in. - [Matthew Barrett](https://github.com/mbaret): @mbaret - [Arnaud Bergeron](https://github.com/abergeron): @abergeron - [Matthew Brookhart](https://github.com/mbrookhart): @mbrookhart -- [Tianqi Chen](https://github.com/tqchen): @tqchen - [Liangfu Chen](https://github.com/liangfu): @liangfu +- [Tianqi Chen](https://github.com/tqchen): @tqchen - [Zhi Chen](https://github.com/zhiics): @zhiics -- [Chenfan](https://github.com/jcf94): @jcf94 - [Neo Chien](https://github.com/cchung100m): @cchung100m - [Meghan Cowan](https://github.com/cowanmeg): @cowanmeg - [Balint Cristian](https://github.com/cbalint13): @cbalint13 +- [Egor Churaev](https://github.com/echuraev): @echuraev - metal +- [Xiaoqiang Dan](https://github.com/xqdan): @xqdan - [Haozheng Fan](https://github.com/hzfan): @hzfan -- [Josh Fromm](https://github.com/jwfromm): @jwfromm - [Siyuan Feng](https://github.com/Hzfengsy): @Hzfengsy +- [Josh Fromm](https://github.com/jwfromm): @jwfromm - [Sergei Grechanik](https://github.com/sgrechanik-h): @sgrechanik-h -- [Hao Lu](https://github.com/hlu1): @hlu1 - [Bohan Hou](https://github.com/spectrometerHBH): @spectrometerHBH -- [Nick Hynes](https://github.com/nhynes): @nhynes - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - [Luke Hutton](https://github.com/lhutton1): @lhutton1 +- [Nick Hynes](https://github.com/nhynes): @nhynes - [Animesh Jain](https://github.com/anijain2305): @anijain2305 +- [Chenfan Jia](https://github.com/jcf94): @jcf94 - [Hua Jiang](https://github.com/huajsj): @huajsj -- [Leandro Nunes](https://github.com/leandron): @leandron -- [Yizhi Liu](https://github.com/yzhliu) : @yzhliu -- [Zhixun Tan](https://github.com/phisiart): @phisiart -- [Xiaoqiang Dan](https://github.com/xqdan): @xqdan - [Ziheng Jiang](https://github.com/ZihengJiang): @ZihengJiang - [Manupa Karunaratne](https://github.com/manupa-arm): @manupa-arm - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame @@ -115,6 +113,8 @@ We do encourage everyone to work anything they are interested in. - [Andrew Liu](https://github.com/hypercubestart): @hypercubestart - [Henry Liu](https://github.com/optima2005): @optima2005 - [Xin Liu](https://github.com/Meteorix): @Meteorix +- [Yizhi Liu](https://github.com/yzhliu) : @yzhliu +- [Hao Lu](https://github.com/hlu1): @hlu1 - [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - [Masahiro Masuda](https://github.com/masahi): @masahi - [Sergey Mironov](https://github.com/grwlf): @grwlf @@ -122,26 +122,28 @@ We do encourage everyone to work anything they are interested in. - [Kazutaka Morita](https://github.com/kazum): @kazum - [Trevor Morris](https://github.com/trevor-m): @trevor-m - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t +- [Leandro Nunes](https://github.com/leandron): @leandron - [Wei Pan](https://github.com/wpan11nv): @wpan11nv - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 - [Josh Pollock](https://github.com/joshpoll): @joshpoll +- [Andrew Reusch](https://github.com/areusch): @areusch - [Jared Roesch](https://github.com/jroesch): @jroesch - [Giuseppe Rossini](https://github.com/giuseros): @giuseros -- [Andrew Reusch](https://github.com/areusch): @areusch -- [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - [Junru Shao](https://github.com/junrushao1994): @junrushao1994 - [Haichen Shen](https://github.com/icemelon9): @icemelon9 - [Xingjian Shi](https://github.com/sxjscience): @sxjscience +- [Siva](https://github.com/srkreddy1238): @srkreddy1238 +- [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov - [Jon Soifer](https://github.com/soiferj): @soiferj +- [Zhixun Tan](https://github.com/phisiart): @phisiart - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - [Thomas Viehmann](https://github.com/t-vi): @t-vi -- [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Yao Wang](https://github.com/kevinthesun): @kevinthesun - [Leyuan Wang](https://github.com/Laurawly): @Laurawly +- [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Logan Weber](https://github.com/weberlo): @weberlo - [Jian Weng](https://github.com/were): @were - [Yong Wu](https://github.com/yongwww): @yongwww @@ -154,9 +156,3 @@ We do encourage everyone to work anything they are interested in. ## List of Contributors - [Full List of Contributors](https://github.com/apache/tvm/graphs/contributors) - - To contributors: please add your name to the list. -- [Qiao Zhang](https://github.com/zhangqiaorjc) -- [Haolong Zhang](https://github.com/haolongzhangm) -- [Cody Hao Yu](https://github.com/comaniac) -- [Chris Nuernberger](https://github.com/cnuernber) -- [Shoubhik Bhattacharya](https://github.com/shoubhik) diff --git a/Jenkinsfile b/Jenkinsfile index a2f1017b66a0..3ea6d22d11d0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,13 +44,13 @@ // // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_lint = "tlcpack/ci-lint:v0.62" -ci_gpu = "tlcpack/ci-gpu:v0.72" -ci_cpu = "tlcpack/ci-cpu:v0.73" -ci_wasm = "tlcpack/ci-wasm:v0.70" -ci_i386 = "tlcpack/ci-i386:v0.72-t0" -ci_qemu = "tlcpack/ci-qemu:v0.04" -ci_arm = "tlcpack/ci-arm:v0.03" +ci_lint = "tlcpack/ci-lint:v0.65" +ci_gpu = "tlcpack/ci-gpu:v0.75" +ci_cpu = "tlcpack/ci-cpu:v0.74" +ci_wasm = "tlcpack/ci-wasm:v0.71" +ci_i386 = "tlcpack/ci-i386:v0.73" +ci_qemu = "tlcpack/ci-qemu:v0.05" +ci_arm = "tlcpack/ci-arm:v0.05" // <--- End of regex-scanned config. // tvm libraries diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index f3c7efd08b5c..07a812c4b840 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++ b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -57,6 +57,7 @@ #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_module.cc" +#include "../src/runtime/source_utils.cc" #endif #ifdef TVM_VULKAN_RUNTIME diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5dcd823929ca..1331e1a65ca8 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -62,6 +62,8 @@ #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_module.cc" +#include "../src/runtime/opencl/texture_pool.cc" +#include "../src/runtime/source_utils.cc" #endif #ifdef TVM_VULKAN_RUNTIME diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 0315a8f11b39..be1b42328c1b 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -44,7 +44,7 @@ def load_lib(): @tvm.register_object("tvm_ext.IntVector") class IntVec(tvm.Object): - """Example for using extension class in c++ """ + """Example for using extension class in c++""" @property def _tvm_handle(self): diff --git a/apps/microtvm/reference-vm/README.md b/apps/microtvm/reference-vm/README.md index 7ef7900c3e05..7ff75c75b4f9 100644 --- a/apps/microtvm/reference-vm/README.md +++ b/apps/microtvm/reference-vm/README.md @@ -49,19 +49,19 @@ Reference VMs are organized as follows: ## Creating Releases -1. Build the base box for the given platform: `$ ./base-box-tool.py build ` +1. Build the base box for the given platform: `$ ./base-box-tool.py [--provider=] build ` 2. Run release tests for each platform: 1. Connect any needed hardware to the VM host machine. - 2. Run tests: `$ ./base-box-tool.py test [--test-device-serial=]`. This + 2. Run tests: `$ ./base-box-tool.py [--provider=] test [--microtvm-platform=] [--test-device-serial=]`. This command does the following for each provider: 1. Copies all files inside `./` except `.vagrant` and `base-box` to `./release-test`. This is done to avoid reusing any VM the developer may have started. - 2. Executes `$ vagrant up --provider=`. + 2. Executes `$ vagrant up [--provider=]`. 3. Finds an attached USB device matching the VID and PID specified in `test-config.json`, and if `--test-device-serial` was given, that serial number (as reported to USB). Creates a rule to autoconnect this device to the VM, and also attaches it to the VM> 4. SSHs to the VM, `cd` to the TVM root directory, and runs `test_cmd` from `test-config.json`. Nonzero status means failure. 3. If release tests fail, fix them and restart from step 1. -4. If release tests pass: `$ ./base-box-tool.py release `. Be sure you've logged +4. If release tests pass: `$ ./base-box-tool.py [--provider=] release <--release-version=> <--platform-version=> `. Be sure you've logged in to Vagrant Cloud using the `vagrant` tool. diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index fb7a9c0b5ce6..c22eff4cdbad 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -34,7 +34,6 @@ THIS_DIR = os.path.realpath(os.path.dirname(__file__) or ".") - # List of vagrant providers supported by this tool ALL_PROVIDERS = ( "parallels", @@ -46,8 +45,11 @@ ALL_MICROTVM_PLATFORMS = ( "stm32f746xx", "nrf5340dk", + "mps2_an521", ) +PACKER_FILE_NAME = "packer.json" + def parse_virtualbox_devices(): output = subprocess.check_output(["VBoxManage", "list", "usbhost"], encoding="utf-8") @@ -173,12 +175,21 @@ def attach_vmware(uuid, vid_hex=None, pid_hex=None, serial=None): "vmware_desktop": attach_vmware, } +# Extra scripts required to execute on provisioning +# in zephyr/base-box/base_box_provision.sh +EXTRA_SCRIPTS = ( + "docker/install/ubuntu_init_zephyr_project.sh", + "docker/install/ubuntu_install_qemu.sh", +) + def generate_packer_config(file_path, providers): builders = [] + provisioners = [] for provider_name in providers: builders.append( { + "name": f"{provider_name}", "type": "vagrant", "box_name": f"microtvm-base-{provider_name}", "output_dir": f"output-packer-{provider_name}", @@ -189,10 +200,26 @@ def generate_packer_config(file_path, providers): } ) + repo_root = subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=os.path.dirname(__file__), encoding="utf-8" + ).strip() + for script in EXTRA_SCRIPTS: + script_path = os.path.join(repo_root, script) + filename = os.path.basename(script_path) + provisioners.append({"type": "file", "source": script_path, "destination": f"~/{filename}"}) + + provisioners.append( + { + "type": "shell", + "script": "base_box_provision.sh", + } + ) + with open(file_path, "w") as f: json.dump( { "builders": builders, + "provisioners": provisioners, }, f, sort_keys=True, @@ -202,7 +229,7 @@ def generate_packer_config(file_path, providers): def build_command(args): generate_packer_config( - os.path.join(THIS_DIR, args.platform, "base-box", "packer.json"), + os.path.join(THIS_DIR, args.platform, "base-box", PACKER_FILE_NAME), args.provider or ALL_PROVIDERS, ) env = copy.copy(os.environ) @@ -212,7 +239,7 @@ def build_command(args): if args.debug_packer: packer_args += ["-debug"] - packer_args += ["packer.json"] + packer_args += [PACKER_FILE_NAME] subprocess.check_call( packer_args, cwd=os.path.join(THIS_DIR, args.platform, "base-box"), env=env ) @@ -221,7 +248,6 @@ def build_command(args): REQUIRED_TEST_CONFIG_KEYS = { "vid_hex": str, "pid_hex": str, - "test_cmd": list, } @@ -284,7 +310,6 @@ def do_build_release_test_vm(release_test_dir, user_box_dir, base_box_dir, provi return_code = subprocess.call(remove_args, cwd=release_test_dir) assert return_code in (0, 1), f'{" ".join(remove_args)} returned exit code {return_code}' subprocess.check_call(["vagrant", "up", f"--provider={provider_name}"], cwd=release_test_dir) - return True @@ -293,18 +318,30 @@ def do_run_release_test(release_test_dir, provider_name, test_config, test_devic os.path.join(release_test_dir, ".vagrant", "machines", "default", provider_name, "id") ) as f: machine_uuid = f.read() - ATTACH_USB_DEVICE[provider_name]( - machine_uuid, - vid_hex=test_config["vid_hex"], - pid_hex=test_config["pid_hex"], - serial=test_device_serial, - ) + + # Check if target is not QEMU + if test_config["vid_hex"] and test_config["pid_hex"]: + ATTACH_USB_DEVICE[provider_name]( + machine_uuid, + vid_hex=test_config["vid_hex"], + pid_hex=test_config["pid_hex"], + serial=test_device_serial, + ) tvm_home = os.path.realpath(os.path.join(THIS_DIR, "..", "..", "..")) def _quote_cmd(cmd): return " ".join(shlex.quote(a) for a in cmd) - test_cmd = _quote_cmd(["cd", tvm_home]) + " && " + _quote_cmd(test_config["test_cmd"]) + test_cmd = ( + _quote_cmd(["cd", tvm_home]) + + " && " + + _quote_cmd( + [ + "apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh", + test_config["microtvm_platform"], + ] + ) + ) subprocess.check_call(["vagrant", "ssh", "-c", f"bash -ec '{test_cmd}'"], cwd=release_test_dir) @@ -325,6 +362,7 @@ def test_command(args): microtvm_test_platform["vid_hex"] = microtvm_test_platform["vid_hex"].lower() microtvm_test_platform["pid_hex"] = microtvm_test_platform["pid_hex"].lower() + microtvm_test_platform["microtvm_platform"] = args.microtvm_platform providers = args.provider provider_passed = {p: False for p in providers} @@ -399,18 +437,18 @@ def parse_args(): description="Automates building, testing, and releasing a base box" ) subparsers = parser.add_subparsers(help="Action to perform.") - parser.add_argument( - "platform", - help="Name of the platform VM to act on. Must be a sub-directory of this directory.", - ) parser.add_argument( "--provider", choices=ALL_PROVIDERS, action="append", - default=list(ALL_PROVIDERS), help="Name of the provider or providers to act on; if not specified, act on all.", ) + parser.add_argument( + "platform", + help="Name of the platform VM to act on. Must be a sub-directory of this directory.", + ) + parser_build = subparsers.add_parser("build", help="Build a base box.") parser_build.set_defaults(func=build_command) parser_test = subparsers.add_parser("test", help="Test a base box before release.") diff --git a/apps/microtvm/reference-vm/zephyr/Vagrantfile b/apps/microtvm/reference-vm/zephyr/Vagrantfile index f335565341db..be41c0b733e5 100644 --- a/apps/microtvm/reference-vm/zephyr/Vagrantfile +++ b/apps/microtvm/reference-vm/zephyr/Vagrantfile @@ -18,6 +18,18 @@ Vagrant.configure("2") do |config| config.vm.box = "tlcpack/microtvm-zephyr-2.5" + if ENV.has_key?("TVM_RVM_NUM_CORES") + num_cores = ENV["TVM_RVM_NUM_CORES"] + else + num_cores = 2 + end + + if ENV.has_key?("TVM_RVM_RAM_BYTES") + ram_bytes = ENV["TVM_RVM_RAM_BYTES"] + else + ram_bytes = 2048 + end + tvm_home = "../../../.." dirs_to_mount = [Pathname.new(Pathname.new(tvm_home).expand_path())] if ENV.has_key?("TVM_PROJECT_DIR") then @@ -34,12 +46,14 @@ Vagrant.configure("2") do |config| end end - config.vm.provision "shell", path: "setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false + config.vm.provision "shell", path: "provision_setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false # Enable USB Controller on VirtualBox vm_name = "microtvm-#{Time.now.tv_sec}" config.vm.provider "virtualbox" do |vb, overrides| vb.name = vm_name + vb.cpus = num_cores + vb.memory = ram_bytes vb.customize ["modifyvm", :id, "--usb", "on"] vb.customize ["modifyvm", :id, "--usbehci", "on"] vb.customize ["modifyvm", :id, "--usbxhci", "on"] @@ -50,6 +64,8 @@ Vagrant.configure("2") do |config| config.vm.provider "parallels" do |prl, overrides| prl.name = vm_name + prl.cpus = num_cores + prl.memory = ram_bytes prl.update_guest_tools = true prl.customize ["set", :id, "--support-usb30", "on"] dirs_to_mount.each do |d| @@ -58,6 +74,8 @@ Vagrant.configure("2") do |config| end config.vm.provider "vmware_desktop" do |vm, overrides| + vm.cpus = num_cores + vm.memory = ram_bytes vm.vmx["usb_xhci.present"] = "TRUE" vm.vmx["usb.present"] = "TRUE" vm.vmx["ehci.present"] = "TRUE" diff --git a/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template b/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template index 38f9a20b56cf..b43596bb83c1 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template +++ b/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template @@ -41,7 +41,7 @@ Vagrant.configure("2") do |config| config.vm.provision "shell", inline: "touch ~/skip_zeroing_disk", privileged: false {{- end}} - # NOTE: setup.sh resides in the parent directory (../) because this template is expanded into a + # NOTE: base_box_setup.sh resides in the parent directory (../) because this template is expanded into a # sub-directory of base-box (output-packer-*). - config.vm.provision "shell", path: "../setup.sh", privileged: false + config.vm.provision "shell", path: "../base_box_setup.sh", privileged: false end diff --git a/tests/python/relay/test_pass_profiler.py b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh similarity index 52% rename from tests/python/relay/test_pass_profiler.py rename to apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh index acf6c8c50aff..69e6171d06dd 100644 --- a/tests/python/relay/test_pass_profiler.py +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh @@ -1,3 +1,4 @@ +#!/bin/bash -e # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,28 +15,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -import tvm.relay -from tvm.relay import op - +# +# Using this script we can reuse docker/install scripts to configure the reference +# virtual machine similar to CI QEMU setup. +# -def test_pass_profiler(): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) +set -e +set -x - tvm.transform.enable_pass_profiling() +source ~/.profile - mod = tvm.relay.transform.AnnotateSpans()(mod) - mod = tvm.relay.transform.ToANormalForm()(mod) - mod = tvm.relay.transform.InferType()(mod) +# Init Zephyr +cd ~ +# Using most recent commit that passes all the tests. +~/ubuntu_init_zephyr_project.sh ~/zephyr v2.5-branch --commit dabf23758417fd041fec2a2a821d8f526afac29d - profiles = tvm.transform.render_pass_profiles() - assert "AnnotateSpans" in profiles - assert "ToANormalForm" in profiles - assert "InferType" in profiles +# Build QEMU +sudo ~/ubuntu_install_qemu.sh --target-list arm-softmmu - tvm.transform.clear_pass_profiles() - tvm.transform.disable_pass_profiling() +# Cleanup +rm -f *.sh diff --git a/apps/microtvm/reference-vm/zephyr/base-box/setup.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh similarity index 97% rename from apps/microtvm/reference-vm/zephyr/base-box/setup.sh rename to apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh index 8f7ed41af337..e8385af9f663 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/setup.sh +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh @@ -17,6 +17,7 @@ # under the License. set -e +set -x skip_zeroing_disk=0 if [ -e "$HOME/skip_zeroing_disk" ]; then @@ -81,8 +82,6 @@ pip3 install --user -U west echo 'export PATH=$HOME/.local/bin:"$PATH"' >> ~/.profile source ~/.profile echo PATH=$PATH -REPO_ROOT=$(git rev-parse --show-toplevel) -${REPO_ROOT}/docker/install/ubuntu_init_zephyr_project.sh ~/zephyr v2.5.0 cd ~ echo "Downloading zephyr SDK..." diff --git a/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh new file mode 100755 index 000000000000..8eba63e9e331 --- /dev/null +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh @@ -0,0 +1,39 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Usage: base_box_test.sh +# Execute microTVM Zephyr tests. +# + +set -e +set -x + +if [ "$#" -lt 1 ]; then + echo "Usage: base_box_test.sh " + exit -1 +fi + +microtvm_platform=$1 + +pytest tests/micro/zephyr/test_zephyr.py --microtvm-platforms=${microtvm_platform} + +if [ $microtvm_platform == "stm32f746xx" ]; then + echo "NOTE: skipped test_zephyr_aot.py on $microtvm_platform -- known failure" +else + pytest tests/micro/zephyr/test_zephyr_aot.py --microtvm-platforms=${microtvm_platform} +fi diff --git a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json index 1a39d34c7e64..48b6915a10f4 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json +++ b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json @@ -1,12 +1,14 @@ { "stm32f746xx": { "vid_hex": "0483", - "pid_hex": "374b", - "test_cmd": ["pytest", "tests/micro/zephyr/test_zephyr.py", "--microtvm-platforms=stm32f746xx"] + "pid_hex": "374b" }, "nrf5340dk": { "vid_hex": "1366", - "pid_hex": "1055", - "test_cmd": ["pytest", "tests/micro/zephyr/test_zephyr.py", "--microtvm-platforms=nrf5340dk"] + "pid_hex": "1055" + }, + "mps2_an521": { + "vid_hex": "", + "pid_hex": "" } } diff --git a/apps/microtvm/reference-vm/zephyr/setup.sh b/apps/microtvm/reference-vm/zephyr/provision_setup.sh similarity index 95% rename from apps/microtvm/reference-vm/zephyr/setup.sh rename to apps/microtvm/reference-vm/zephyr/provision_setup.sh index e0f382cfc23e..f95c7e24f5aa 100644 --- a/apps/microtvm/reference-vm/zephyr/setup.sh +++ b/apps/microtvm/reference-vm/zephyr/provision_setup.sh @@ -24,6 +24,7 @@ cd "${TVM_HOME}" apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +# Build poetry cd apps/microtvm/reference-vm/zephyr poetry env use 3.6 @@ -41,7 +42,7 @@ echo "downloaded and cached for future use." echo "------------------------------[ TVM Message ]------------------------------" poetry lock -vvv poetry install -poetry run pip3 install -r ~/zephyr/zephyr/scripts/requirements.txt +poetry run pip3 install -r ${ZEPHYR_BASE}/scripts/requirements.txt echo "export TVM_LIBRARY_PATH=\"$TVM_HOME\"/build-microtvm" >>~/.profile echo "VENV_PATH=\$((cd \"$TVM_HOME\"/apps/microtvm/reference-vm/zephyr && poetry env list --full-path) | sed -E 's/^(.*)[[:space:]]\(Activated\)\$/\1/g')" >>~/.profile diff --git a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh index 2eb55e385520..1cebcf7166af 100755 --- a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +++ b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh @@ -18,6 +18,14 @@ set -e +# Get number of cores for build +if [ -n "${TVM_CI_NUM_CORES}" ]; then + num_cores=${TVM_CI_NUM_CORES} +else + # default setup for Vagrantfile + num_cores=2 +fi + cd "$(dirname $0)" cd "$(git rev-parse --show-toplevel)" BUILD_DIR=build-microtvm @@ -32,4 +40,4 @@ sed -i 's/USE_GRAPH_EXECUTOR_DEBUG OFF/USE_GRAPH_EXECUTOR_DEBUG ON/' config.cmak sed -i 's/USE_LLVM OFF/USE_LLVM ON/' config.cmake cmake .. rm -rf standalone_crt host_standalone_crt # remove stale generated files -make -j4 +make -j${num_cores} diff --git a/apps/microtvm/zephyr/demo_runtime/CMakeLists.txt b/apps/microtvm/zephyr/aot_demo/CMakeLists.txt similarity index 88% rename from apps/microtvm/zephyr/demo_runtime/CMakeLists.txt rename to apps/microtvm/zephyr/aot_demo/CMakeLists.txt index a99d5edb07e6..d7ec2a25db14 100644 --- a/apps/microtvm/zephyr/demo_runtime/CMakeLists.txt +++ b/apps/microtvm/zephyr/aot_demo/CMakeLists.txt @@ -10,8 +10,9 @@ find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) project(microtvm_zephyr_runtime) set(CMAKE_VERBOSE_MAKEFILE ON) -file(GLOB TVM_SOURCES ${CMAKE_SOURCE_DIR}/__tvm*.c) -target_sources(app PRIVATE src/main.c ${TVM_SOURCES}) + +target_sources(app PRIVATE src/zephyr_uart.c) +target_sources(app PRIVATE src/main.c) foreach(tvm_lib ${TVM_LIBS}) string(LENGTH ${tvm_lib} tvm_lib_length) diff --git a/apps/microtvm/zephyr/aot_demo/README.md b/apps/microtvm/zephyr/aot_demo/README.md new file mode 100644 index 000000000000..a718da65e2fa --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/README.md @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + +This directory contains a Zephyr-based ahead of time (AOT) "demo" runtime environment that +pulls together the microTVM runtime dependencies into a single application +that can run TVM on a microTVM device without the need to a host. diff --git a/apps/microtvm/zephyr/demo_runtime/boards/mps2_an521.conf b/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/mps2_an521.conf rename to apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf diff --git a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf new file mode 100644 index 000000000000..d298325eb4a4 --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file is specific to the nRF5340 DK board. + +# For intrinsics used by generated optimized operators. +CONFIG_CMSIS_DSP=y + +# For AOT runtime which requires lots of function call. +CONFIG_MAIN_STACK_SIZE=2000 + +# For random number generation. +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y + +# For debugging. +CONFIG_LED=y diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf b/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf new file mode 100644 index 000000000000..5f3c4a4bed36 --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is specific to the QEMU-emulated microTVM board. + +# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_TIMER_RANDOM_GENERATOR=y + +# Default stack size is 1k, this is required for debug mode. +CONFIG_MAIN_STACK_SIZE=2000 diff --git a/apps/microtvm/zephyr/aot_demo/crt/crt_config.h b/apps/microtvm/zephyr/aot_demo/crt/crt_config.h new file mode 100644 index 000000000000..9ee315aa1763 --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/crt/crt_config.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/crt_config.h.template + * \brief Template for CRT configuration, to be modified on each target. + */ +#ifndef TVM_RUNTIME_CRT_CONFIG_H_ +#define TVM_RUNTIME_CRT_CONFIG_H_ + +#include + +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + +/*! Maximum supported dimension in NDArray */ +#define TVM_CRT_MAX_NDIM 6 + +/*! Maximum supported arguments in generated functions */ +#define TVM_CRT_MAX_ARGS 10 + +/*! Size of the global function registry, in bytes. */ +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 + +/*! Maximum number of registered modules. */ +#define TVM_CRT_MAX_REGISTERED_MODULES 2 + +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES (1 * 1024) + +/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 + +/*! Maximum supported string length in function names */ +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 + +/*! \brief Maximum length of a PackedFunc function name. */ +#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 + +/*! \brief Log2 of the page size (bytes) for a virtual memory page. */ +#define TVM_CRT_PAGE_BITS 10 // 1 kB + +/*! \brief Number of pages on device. */ +#define TVM_CRT_MAX_PAGES 300 + +#endif // TVM_RUNTIME_CRT_CONFIG_H_ diff --git a/apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h b/apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h new file mode 100644 index 000000000000..f24ade734c4f --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_APPS_MICROTVM_ZEPHYR_AOT_DEMO_INCLUDE_ZEPHYR_UART_H_ +#define TVM_APPS_MICROTVM_ZEPHYR_AOT_DEMO_INCLUDE_ZEPHYR_UART_H_ + +#include + +// Used to read data from the UART. + +/*! + * \brief Read Uart Rx buffer. + * \param data Pointer to read data. + * \param data_size_bytes Read request size in bytes. + * + * \return Number of data read in bytes. + */ +uint32_t TVMPlatformUartRxRead(uint8_t* data, uint32_t data_size_bytes); + +/*! + * \brief Write data in serial. + * \param data Pointer to data to write. + * \param size Size of data in bytes. + * + * \return Number of write in bytes. + */ +uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size); + +/*! + * \brief Initialize Uart. + */ +void TVMPlatformUARTInit(); + +#endif /* TVM_APPS_MICROTVM_ZEPHYR_AOT_DEMO_INCLUDE_ZEPHYR_UART_H_ */ diff --git a/apps/microtvm/zephyr/demo_runtime/prj.conf b/apps/microtvm/zephyr/aot_demo/prj.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/prj.conf rename to apps/microtvm/zephyr/aot_demo/prj.conf diff --git a/apps/microtvm/zephyr/aot_demo/qemu-hack b/apps/microtvm/zephyr/aot_demo/qemu-hack new file mode 120000 index 000000000000..b4810f2aab6e --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/qemu-hack @@ -0,0 +1 @@ +../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/aot_demo/src/main.c new file mode 100644 index 000000000000..b92366a7098b --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/src/main.c @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "input_data.h" +#include "output_data.h" +#include "zephyr_uart.h" + +#ifdef CONFIG_ARCH_POSIX +#include "posix_board_if.h" +#endif + +#define WORKSPACE_SIZE (270 * 1024) + +static uint8_t g_aot_memory[WORKSPACE_SIZE]; +extern tvm_model_t network; +tvm_workspace_t app_workspace; + +// Wakeup sequence used to wake up QEMU on the host. +const unsigned char g_wakeup_sequence[12] = {0xfe, 0xff, 0xfd, 0x03, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x66, 0x77}; +const char g_start_cmd[] = "start\n"; + +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); +} + +void TVMLogf(const char* msg, ...) { + char buffer[256]; + int size; + va_list args; + va_start(args, msg); + size = vsprintf(buffer, msg, args); + va_end(args); + TVMPlatformWriteSerial(buffer, (uint32_t)size); +} + +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: %08x\n", error); + sys_reboot(SYS_REBOOT_COLD); + for (;;) + ; +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return StackMemoryManager_Free(&app_workspace, ptr); +} + +void timer_expiry_function(struct k_timer* timer_id) { return; } + +#define MILLIS_TIL_EXPIRY 200 +#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) +struct k_timer g_utvm_timer; +uint32_t g_utvm_start_time; +int g_utvm_timer_running = 0; + +// Called to start system timer. +tvm_crt_error_t TVMPlatformTimerStart() { + if (g_utvm_timer_running) { + TVMLogf("timer already running"); + return kTvmErrorPlatformTimerBadState; + } + + k_timer_start(&g_utvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); + g_utvm_start_time = k_cycle_get_32(); + g_utvm_timer_running = 1; + return kTvmErrorNoError; +} + +// Called to stop system timer. +tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { + if (!g_utvm_timer_running) { + TVMLogf("timer not running"); + return kTvmErrorSystemErrorMask | 2; + } + + uint32_t stop_time = k_cycle_get_32(); + + // compute how long the work took + uint32_t cycles_spent = stop_time - g_utvm_start_time; + if (stop_time < g_utvm_start_time) { + // we rolled over *at least* once, so correct the rollover it was *only* + // once, because we might still use this result + cycles_spent = ~((uint32_t)0) - (g_utvm_start_time - stop_time); + } + + uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent); + double hw_clock_res_us = ns_spent / 1000.0; + + // need to grab time remaining *before* stopping. when stopped, this function + // always returns 0. + int32_t time_remaining_ms = k_timer_remaining_get(&g_utvm_timer); + k_timer_stop(&g_utvm_timer); + // check *after* stopping to prevent extra expiries on the happy path + if (time_remaining_ms < 0) { + return kTvmErrorSystemErrorMask | 3; + } + uint32_t num_expiries = k_timer_status_get(&g_utvm_timer); + uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); + double approx_num_cycles = + (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); + // if we approach the limits of the HW clock datatype (uint32_t), use the + // coarse-grained timer result instead + if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) { + *elapsed_time_seconds = timer_res_ms / 1000.0; + } else { + *elapsed_time_seconds = hw_clock_res_us / 1e6; + } + + g_utvm_timer_running = 0; + return kTvmErrorNoError; +} + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, + int dtype_bits_hint) { + tvm_crt_error_t err = kTvmErrorNoError; + void* ptr = 0; + DLDevice dev = {device_type, device_id}; + assert(nbytes > 0); + err = TVMPlatformMemoryAllocate(nbytes, dev, &ptr); + CHECK_EQ(err, kTvmErrorNoError, + "TVMBackendAllocWorkspace(%d, %d, %" PRIu64 ", %d, %d) -> %" PRId32, device_type, + device_id, nbytes, dtype_code_hint, dtype_bits_hint, err); + return ptr; +} + +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + tvm_crt_error_t err = kTvmErrorNoError; + DLDevice dev = {device_type, device_id}; + err = TVMPlatformMemoryFree(ptr, dev); + return err; +} + +static uint8_t main_rx_buf[128]; +static uint8_t cmd_buf[128]; +static size_t g_cmd_buf_ind; + +void main(void) { + g_cmd_buf_ind = 0; + memset((char*)cmd_buf, 0, sizeof(cmd_buf)); + TVMPlatformUARTInit(); + k_timer_init(&g_utvm_timer, NULL, NULL); + // Wake up host side. + TVMPlatformWriteSerial(g_wakeup_sequence, sizeof(g_wakeup_sequence)); + + // Wait for start command + while (true) { + int bytes_read = TVMPlatformUartRxRead(main_rx_buf, sizeof(main_rx_buf)); + if (bytes_read > 0) { + memcpy((char*)cmd_buf + g_cmd_buf_ind, main_rx_buf, bytes_read); + g_cmd_buf_ind += bytes_read; + } + if (g_cmd_buf_ind >= 6) { + if (!strcmp((char*)(cmd_buf), g_start_cmd)) { + break; + } else { + memset((char*)cmd_buf, 0, sizeof(cmd_buf)); + g_cmd_buf_ind = 0; + } + } + } + TVMLogf("Zephyr AOT Runtime\n"); + + void* inputs[1] = { + input_data, + }; + void* outputs[1] = { + output_data, + }; + + StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); + + double elapsed_time = 0; + TVMPlatformTimerStart(); + int ret_val = tvm_runtime_run(&network, inputs, outputs); + TVMPlatformTimerStop(&elapsed_time); + + if (ret_val != 0) { + TVMLogf("Error: %d\n", ret_val); + TVMPlatformAbort(kTvmErrorPlatformCheckFailure); + } + + size_t max_ind = -1; + float max_val = -FLT_MAX; + for (size_t i = 0; i < output_data_len; i++) { + if (output_data[i] >= max_val) { + max_ind = i; + max_val = output_data[i]; + } + } + TVMLogf("#result:%d:%d\n", max_ind, (uint32_t)(elapsed_time * 1000)); +#ifdef CONFIG_ARCH_POSIX + posix_exit(0); +#endif +} diff --git a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c new file mode 100644 index 000000000000..1f4dde1de4b9 --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "zephyr_uart.h" + +#include +#include + +#include "crt_config.h" + +static const struct device* g_utvm_uart; +#define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) + +// Ring buffer used to store data read from the UART on rx interrupt. +RING_BUF_DECLARE(uart_rx_rbuf, RING_BUF_SIZE_BYTES); + +static uint8_t uart_data[8]; +// UART interrupt callback. +void uart_irq_cb(const struct device* dev, void* user_data) { + while (uart_irq_update(dev) && uart_irq_is_pending(dev)) { + struct ring_buf* rbuf = (struct ring_buf*)user_data; + if (uart_irq_rx_ready(dev) != 0) { + for (;;) { + // Read a small chunk of data from the UART. + int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data)); + if (bytes_read < 0) { + TVMPlatformAbort((tvm_crt_error_t)(0xbeef1)); + } else if (bytes_read == 0) { + break; + } + // Write it into the ring buffer. + int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read); + if (bytes_read != bytes_written) { + TVMPlatformAbort((tvm_crt_error_t)(0xbeef2)); + } + } + } + } +} + +// Used to initialize the UART receiver. +void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { + uart_irq_callback_user_data_set(dev, uart_irq_cb, (void*)rbuf); + uart_irq_rx_enable(dev); +} + +uint32_t TVMPlatformUartRxRead(uint8_t* data, uint32_t data_size_bytes) { + unsigned int key = irq_lock(); + uint32_t bytes_read = ring_buf_get(&uart_rx_rbuf, data, data_size_bytes); + irq_unlock(key); + return bytes_read; +} + +uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { + for (uint32_t i = 0; i < size; i++) { + uart_poll_out(g_utvm_uart, data[i]); + } + return size; +} + +// Initialize UART +void TVMPlatformUARTInit() { + // Claim console device. + g_utvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console))); + uart_rx_init(&uart_rx_rbuf, g_utvm_uart); +} diff --git a/apps/microtvm/zephyr/host_driven/CMakeLists.txt b/apps/microtvm/zephyr/host_driven/CMakeLists.txt new file mode 100644 index 000000000000..f04a792086cb --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/CMakeLists.txt @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.13.1) + +set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") + +set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. + +find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) +project(microtvm_zephyr_runtime) + +set(CMAKE_VERBOSE_MAKEFILE ON) + +target_sources(app PRIVATE src/main.c) + +foreach(tvm_lib ${TVM_LIBS}) + string(LENGTH ${tvm_lib} tvm_lib_length) + math(EXPR tvm_lib_cut "${tvm_lib_length} - 2") + string(SUBSTRING ${tvm_lib} 3 ${tvm_lib_cut} tvm_lib_name) + add_library(${tvm_lib_name} STATIC IMPORTED) + set_target_properties(${tvm_lib_name} PROPERTIES + IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/${tvm_lib}) + target_link_libraries(app PRIVATE ${tvm_lib_name}) +endforeach(tvm_lib ${TVM_LIBS}) + +target_include_directories(app PRIVATE ${TVM_INCLUDE_DIRS}) diff --git a/apps/microtvm/zephyr/demo_runtime/README.md b/apps/microtvm/zephyr/host_driven/README.md similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/README.md rename to apps/microtvm/zephyr/host_driven/README.md diff --git a/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf b/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf new file mode 100644 index 000000000000..3916b17c49cf --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file is specific to the MPS2-AN512 board. + +# For intrinsics used by generated optimized operators. +CONFIG_CMSIS_DSP=y + +# For random number generation. +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y + +# For debugging. +CONFIG_LED=n diff --git a/apps/microtvm/zephyr/demo_runtime/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/nrf5340dk_nrf5340_cpuapp.conf rename to apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf diff --git a/apps/microtvm/zephyr/demo_runtime/boards/nucleo_f746zg.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/nucleo_f746zg.conf rename to apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf diff --git a/apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv32.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv32.conf rename to apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf diff --git a/apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv64.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv64.conf rename to apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf diff --git a/apps/microtvm/zephyr/demo_runtime/boards/qemu_x86.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/qemu_x86.conf rename to apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf diff --git a/apps/microtvm/zephyr/demo_runtime/boards/stm32f746g_disco.conf b/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/boards/stm32f746g_disco.conf rename to apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf diff --git a/apps/microtvm/zephyr/demo_runtime/crt/crt_config.h b/apps/microtvm/zephyr/host_driven/crt/crt_config.h similarity index 97% rename from apps/microtvm/zephyr/demo_runtime/crt/crt_config.h rename to apps/microtvm/zephyr/host_driven/crt/crt_config.h index f8fc7514a28d..658b97e267ba 100644 --- a/apps/microtvm/zephyr/demo_runtime/crt/crt_config.h +++ b/apps/microtvm/zephyr/host_driven/crt/crt_config.h @@ -42,7 +42,7 @@ #define TVM_CRT_MAX_REGISTERED_MODULES 2 /*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8192 +#define TVM_CRT_MAX_PACKET_SIZE_BYTES (4 * 1024) /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 diff --git a/apps/microtvm/zephyr/host_driven/prj.conf b/apps/microtvm/zephyr/host_driven/prj.conf new file mode 100644 index 000000000000..5f4d7a0689dc --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/prj.conf @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The settings in this file are generic for all boards, and are merged +# with the settings in the file boards/.conf by the Zephyr build +# process. + +# For UART implementation in main(). +CONFIG_RING_BUFFER=y +CONFIG_UART_CONSOLE=n +CONFIG_UART_INTERRUPT_DRIVEN=y + +# For RPC server C++ bindings. +CONFIG_CPLUSPLUS=y +CONFIG_NEWLIB_LIBC=y + +# For models with floating point. +CONFIG_FPU=y + +# For TVMPlatformAbort(). +CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/host_driven/qemu-hack b/apps/microtvm/zephyr/host_driven/qemu-hack new file mode 120000 index 000000000000..b4810f2aab6e --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/qemu-hack @@ -0,0 +1 @@ +../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/demo_runtime/src/main.c b/apps/microtvm/zephyr/host_driven/src/main.c similarity index 85% rename from apps/microtvm/zephyr/demo_runtime/src/main.c rename to apps/microtvm/zephyr/host_driven/src/main.c index 4acca0b9ca12..637a58ae92fd 100644 --- a/apps/microtvm/zephyr/demo_runtime/src/main.c +++ b/apps/microtvm/zephyr/host_driven/src/main.c @@ -61,6 +61,7 @@ static const struct device* led0_pin; static size_t g_num_bytes_requested = 0; static size_t g_num_bytes_written = 0; +static size_t g_num_bytes_in_rx_buffer = 0; // Called by TVM to write serial data to the UART. ssize_t write_serial(void* unused_context, const uint8_t* data, size_t size) { @@ -99,6 +100,7 @@ size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const // Called by TVM when an internal invariant is violated, and execution cannot continue. void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMError: %x", error); sys_reboot(SYS_REBOOT_COLD); #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 1); @@ -214,33 +216,37 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { } // Ring buffer used to store data read from the UART on rx interrupt. -#define RING_BUF_SIZE_BYTES 4 * 1024 -RING_BUF_DECLARE(uart_rx_rbuf, RING_BUF_SIZE_BYTES); - -// Small buffer used to read data from the UART into the ring buffer. -static uint8_t uart_data[8]; +// This ring buffer size is only required for testing with QEMU and not for physical hardware. +#define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) +RING_BUF_ITEM_DECLARE_SIZE(uart_rx_rbuf, RING_BUF_SIZE_BYTES); // UART interrupt callback. void uart_irq_cb(const struct device* dev, void* user_data) { - while (uart_irq_update(dev) && uart_irq_is_pending(dev)) { + uart_irq_update(dev); + if (uart_irq_is_pending(dev)) { struct ring_buf* rbuf = (struct ring_buf*)user_data; if (uart_irq_rx_ready(dev) != 0) { - for (;;) { - // Read a small chunk of data from the UART. - int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data)); - if (bytes_read < 0) { - TVMPlatformAbort((tvm_crt_error_t)0xbeef1); - } else if (bytes_read == 0) { - break; - } - // Write it into the ring buffer. - int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read); - if (bytes_read != bytes_written) { - TVMPlatformAbort((tvm_crt_error_t)0xbeef2); - } - // CHECK_EQ(bytes_read, bytes_written, "bytes_read: %d; bytes_written: %d", bytes_read, - // bytes_written); + uint8_t* data; + uint32_t size; + size = ring_buf_put_claim(rbuf, &data, RING_BUF_SIZE_BYTES); + int rx_size = uart_fifo_read(dev, data, size); + // Write it into the ring buffer. + g_num_bytes_in_rx_buffer += rx_size; + + if (g_num_bytes_in_rx_buffer > RING_BUF_SIZE_BYTES) { + TVMPlatformAbort((tvm_crt_error_t)0xbeef3); } + + if (rx_size < 0) { + TVMPlatformAbort((tvm_crt_error_t)0xbeef1); + } + + int err = ring_buf_put_finish(rbuf, rx_size); + if (err != 0) { + TVMPlatformAbort((tvm_crt_error_t)0xbeef2); + } + // CHECK_EQ(bytes_read, bytes_written, "bytes_read: %d; bytes_written: %d", bytes_read, + // bytes_written); } } } @@ -251,17 +257,6 @@ void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { uart_irq_rx_enable(dev); } -// Used to read data from the UART. -int uart_rx_buf_read(struct ring_buf* rbuf, uint8_t* data, size_t data_size_bytes) { - unsigned int key = irq_lock(); - int bytes_read = ring_buf_get(rbuf, data, data_size_bytes); - irq_unlock(key); - return bytes_read; -} - -// Buffer used to read from the UART rx ring buffer and feed it to the UTvmRpcServerLoop. -static uint8_t main_rx_buf[RING_BUF_SIZE_BYTES]; - // The main function of this application. extern void __stdout_hook_install(int (*hook)(int)); void main(void) { @@ -299,13 +294,15 @@ void main(void) { // The main application loop. We continuously read commands from the UART // and dispatch them to UTvmRpcServerLoop(). while (true) { - int bytes_read = uart_rx_buf_read(&uart_rx_rbuf, main_rx_buf, sizeof(main_rx_buf)); + uint8_t* data; + unsigned int key = irq_lock(); + uint32_t bytes_read = ring_buf_get_claim(&uart_rx_rbuf, &data, RING_BUF_SIZE_BYTES); if (bytes_read > 0) { + g_num_bytes_in_rx_buffer -= bytes_read; size_t bytes_remaining = bytes_read; - uint8_t* cursor = main_rx_buf; while (bytes_remaining > 0) { // Pass the received bytes to the RPC server. - tvm_crt_error_t err = UTvmRpcServerLoop(server, &cursor, &bytes_remaining); + tvm_crt_error_t err = UTvmRpcServerLoop(server, &data, &bytes_remaining); if (err != kTvmErrorNoError && err != kTvmErrorFramingShortPacket) { TVMPlatformAbort(err); } @@ -317,7 +314,12 @@ void main(void) { g_num_bytes_requested = 0; } } + int err = ring_buf_get_finish(&uart_rx_rbuf, bytes_read); + if (err != 0) { + TVMPlatformAbort((tvm_crt_error_t)0xbeef6); + } } + irq_unlock(key); } #ifdef CONFIG_ARCH_POSIX diff --git a/apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/qemu-hack/qemu-system-arm similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-arm rename to apps/microtvm/zephyr/qemu-hack/qemu-system-arm diff --git a/apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-i386 b/apps/microtvm/zephyr/qemu-hack/qemu-system-i386 similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-i386 rename to apps/microtvm/zephyr/qemu-hack/qemu-system-i386 diff --git a/apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32 similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv32 rename to apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32 diff --git a/apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64 similarity index 100% rename from apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv64 rename to apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64 diff --git a/cmake/libs/Libbacktrace.cmake b/cmake/libs/Libbacktrace.cmake index 742855358809..58eb4e02bb5b 100644 --- a/cmake/libs/Libbacktrace.cmake +++ b/cmake/libs/Libbacktrace.cmake @@ -14,14 +14,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# On MacOS, the default C compiler (/usr/bin/cc) is actually a small script that dispatches to a +# compiler the default SDK (usually /Library/Developer/CommandLineTools/usr/bin/ or +# /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/). CMake +# automatically detects what is being dispatched and uses it instead along with all the flags it +# needs. CMake makes this second compiler avaliable through the CMAKE_C_COMPILER variable, but it +# does not make the necessary flags available. This leads to configuration errors in libbacktrace +# because it can't find system libraries. Our solution is to detect if CMAKE_C_COMPILER lives in +# /Library or /Applications and switch to the default compiler instead. include(ExternalProject) + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" + OR CMAKE_C_COMPILER MATCHES "^/Applications")) + set(c_compiler "/usr/bin/cc") + else() + set(c_compiler "${CMAKE_C_COMPILER}") +endif() + ExternalProject_Add(project_libbacktrace PREFIX libbacktrace SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace CONFIGURE_COMMAND "${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace/configure" - "--prefix=${CMAKE_CURRENT_BINARY_DIR}/libbacktrace" --with-pic + "--prefix=${CMAKE_CURRENT_BINARY_DIR}/libbacktrace" + --with-pic + "CC=${c_compiler}" + "CFLAGS=${CMAKE_C_FLAGS}" + "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}" + "CPP=${c_compiler} -E" + "NM=${CMAKE_NM}" + "STRIP=${CMAKE_STRIP}" + "--host=${MACHINE_NAME}" INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace" BUILD_COMMAND make INSTALL_COMMAND make install diff --git a/cmake/modules/ClangFlags.cmake b/cmake/modules/ClangFlags.cmake index 841570dc2e12..563c96272063 100644 --- a/cmake/modules/ClangFlags.cmake +++ b/cmake/modules/ClangFlags.cmake @@ -52,7 +52,7 @@ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") -Wno-shorten-64-to-32 -Wno-covered-switch-default -Wno-unused-exception-parameter - -Wno-return-std-move-in-c++11 + -Wno-return-std-move -Wno-over-aligned -Wno-undef -Wno-inconsistent-missing-destructor-override diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 095790f08547..3ee13aa38b98 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -18,37 +18,16 @@ # Be compatible with older version of CMake find_vulkan(${USE_VULKAN}) -# Extra Vulkan runtime options, exposed for advanced users. -tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan Immediate mode -(KHR_push_descriptor extension)" ON IF USE_VULKAN) -tvm_option(USE_VULKAN_DEDICATED_ALLOCATION "Use Vulkan dedicated allocations" ON -IF USE_VULKAN) -tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF - IF USE_VULKAN) - if(USE_VULKAN) if(NOT Vulkan_FOUND) message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN}) endif() include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS}) message(STATUS "Build with Vulkan support") - file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc) + file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc) file(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS}) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY}) - - if(USE_VULKAN_IMMEDIATE_MODE) - message(STATUS "Build with Vulkan immediate mode") - add_definitions(-DUSE_VULKAN_IMMEDIATE_MODE=1) - endif() - if(USE_VULKAN_DEDICATED_ALLOCATION) - message(STATUS "Build with Vulkan dedicated allocation") - add_definitions(-DUSE_VULKAN_DEDICATED_ALLOCATION=1) - endif() - if(USE_VULKAN_VALIDATION) - message(STATUS "Build with Vulkan API validation") - add_definitions(-DUSE_VULKAN_VALIDATION=1) - endif() endif(USE_VULKAN) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 1ca592f34ab2..4c6a05c34117 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -87,9 +87,9 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh RUN bash /install/ubuntu_install_tensorflow.sh -# Arm(R) Compute Library -COPY install/ubuntu_install_arm_compute_lib.sh /install/ubuntu_install_arm_compute_lib.sh -RUN bash /install/ubuntu_install_arm_compute_lib.sh +# Compute Library +COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh +RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh # Caffe deps COPY install/ubuntu_install_caffe.sh /install/ubuntu_install_caffe.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index d5dbb9115138..09c6425da6fb 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -114,6 +114,10 @@ ENV C_INCLUDE_PATH=/usr/local/cuda/include:${C_INCLUDE_PATH} ENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/compat:${LIBRARY_PATH} ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/compat:${LD_LIBRARY_PATH} +# Ensure the local libcuda have higher priority than the /usr/local/cuda/compact +# since the compact libcuda does not work on non-Tesla gpus +ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:${LD_LIBRARY_PATH} + ENV LD_LIBRARY_PATH=/opt/rocm/lib:${LD_LIBRARY_PATH} ENV PATH=/node_modules/.bin:${PATH} ENV VULKAN_SDK=/usr diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index 61856fff96f4..00599c27f21c 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -29,7 +29,7 @@ RUN bash /install/ubuntu1804_install_python.sh RUN apt-get update && apt-get install -y doxygen graphviz -RUN pip3 install cpplint pylint==2.4.4 mypy black +RUN pip3 install cpplint pylint==2.4.4 mypy black==20.8b1 # java deps for rat COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh diff --git a/docker/bash.sh b/docker/bash.sh index 3c4be21c9b9e..80f4a9577be1 100755 --- a/docker/bash.sh +++ b/docker/bash.sh @@ -20,7 +20,7 @@ # # Start a bash, mount /workspace to be current directory. # -# Usage: bash.sh [-i] [--net=host] +# Usage: bash.sh [-i] [--net=host] [--mount path] # # Usage: docker/bash.sh # Starts an interactive session @@ -46,6 +46,14 @@ if [[ "$1" == "--net=host" ]]; then shift 1 fi +# Mount external directory to the docker +CI_DOCKER_MOUNT_CMD=( ) +if [ "$1" == "--mount" ]; then + shift 1 + CI_DOCKER_MOUNT_CMD=( -v "$1:$1" ) + shift 1 +fi + if [ "$#" -lt 1 ]; then echo "Usage: docker/bash.sh [-i] [--net=host] [COMMAND]" exit -1 @@ -154,6 +162,7 @@ ${DOCKER_BINARY} run --rm --pid=host\ ${WORKSPACE_VOLUMES}\ -v ${WORKSPACE}:/workspace \ -v ${SCRIPT_DIR}:/docker \ + "${CI_DOCKER_MOUNT_CMD[@]}" \ "${EXTRA_MOUNTS[@]}" \ -w /workspace \ -e "CI_BUILD_HOME=/workspace" \ diff --git a/docker/install/ubuntu_install_arm_compute_lib.sh b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh similarity index 53% rename from docker/install/ubuntu_install_arm_compute_lib.sh rename to docker/install/ubuntu_download_arm_compute_lib_binaries.sh index c09bb1290a63..ff8ad0eb9073 100755 --- a/docker/install/ubuntu_install_arm_compute_lib.sh +++ b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh @@ -17,62 +17,46 @@ # under the License. set -e -set -u -set -o pipefail - -repo_url="https://github.com/ARM-software/ComputeLibrary.git" -repo_dir="acl" -install_path="/opt/$repo_dir" -architecture_type=$(uname -i) -target_arch="arm64-v8a" # arm64-v8a / arm64-v8.2-a / armv7a -build_type="native" - -tmpdir=$(mktemp -d) - -cleanup() -{ - rm -rf "$tmpdir" -} - -trap cleanup 0 - -apt-get update && \ -apt-get install -y --no-install-recommends \ - git \ - scons \ - bsdmainutils \ - build-essential # Install cross-compiler when not building natively. # Depending on the architecture selected to compile for, # you may need to install an alternative cross-compiler. if [ "$architecture_type" != "aarch64" ]; then - apt-get install -y --no-install-recommends \ + apt-get update && apt-get install -y --no-install-recommends \ g++-aarch64-linux-gnu \ gcc-aarch64-linux-gnu fi -cd "$tmpdir" +compute_lib_version="v21.05" +compute_lib_base_url="https://github.com/ARM-software/ComputeLibrary/releases/download/${compute_lib_version}" +compute_lib_file_name="arm_compute-${compute_lib_version}-bin-linux.tar.gz" +compute_lib_download_url="${compute_lib_base_url}/${compute_lib_file_name}" -git clone "$repo_url" "$repo_dir" +target_lib="linux-arm64-v8a-neon" -cd "$repo_dir" +# uncomment line below if you need asserts/debug version of the library +# target_lib="${target_lib}-asserts" -# pin version to v21.02 -git checkout "v21.02" +extract_dir="arm_compute-${compute_lib_version}-bin-linux" +install_path="/opt/arm/acl" -if [ "$architecture_type" != "aarch64" ]; then - build_type="cross_compile" -fi +tmpdir=$(mktemp -d) + +cleanup() +{ + rm -rf "$tmpdir" +} + +trap cleanup 0 + +cd "$tmpdir" + +curl -sL "${compute_lib_download_url}" -o "${compute_lib_file_name}" +tar xzf "${compute_lib_file_name}" -scons \ - install_dir="$install_path" \ - Werror=1 \ - -j8 \ - debug=0 \ - asserts=0 \ - neon=1 \ - opencl=0 \ - os=linux \ - arch="$target_arch" \ - build="$build_type" +mkdir -p "${install_path}" +cp -r "${extract_dir}/include" "${install_path}/" +cp -r "${extract_dir}/arm_compute" "${install_path}/include/" +cp -r "${extract_dir}/support" "${install_path}/include/" +cp -r "${extract_dir}/utils" "${install_path}/include/" +cp -r "${extract_dir}/lib/${target_lib}" "${install_path}/lib" diff --git a/docker/install/ubuntu_init_zephyr_project.sh b/docker/install/ubuntu_init_zephyr_project.sh index 2116a4d981f5..573ff30c38a8 100755 --- a/docker/install/ubuntu_init_zephyr_project.sh +++ b/docker/install/ubuntu_init_zephyr_project.sh @@ -16,10 +16,35 @@ # specific language governing permissions and limitations # under the License. +# +# Initialize Zephyr Project. +# +# Usage: ubuntu_init_zephyr_project.sh path branch [--commit hash] +# path is the installation path for the repository. +# branch is the zephyr branch. +# --commit is the commit hash number of zephyrproject repository. If not specified, it uses the latest commit. +# + +set -x + DOWNLOAD_DIR=$1 -ZEPHYR_BRANCH=$2 +shift +ZEPHYR_BRANCH=$1 +shift + +commit_hash= +if [ "$1" == "--commit" ]; then + shift + commit_hash=$1 +fi west init --mr ${ZEPHYR_BRANCH} ${DOWNLOAD_DIR} + +if [ -n "$commit_hash" ]; then + cd ${DOWNLOAD_DIR}/zephyr + git checkout ${commit_hash} +fi + cd ${DOWNLOAD_DIR} west update west zephyr-export diff --git a/docker/install/ubuntu_install_qemu.sh b/docker/install/ubuntu_install_qemu.sh index 0adc81b207d7..b1d375253e05 100755 --- a/docker/install/ubuntu_install_qemu.sh +++ b/docker/install/ubuntu_install_qemu.sh @@ -16,10 +16,39 @@ # specific language governing permissions and limitations # under the License. +# +# Install QEMU on Ubuntu. +# +# Usage: ubuntu_install_qemu.sh [--target-list target0,target1,...] +# --target-list is list of target for QEMU comma seperated. e.g. aarch64-softmmu,arm-softmmu,... +# + set -e -set -u set -o pipefail +QEMU_NAME=qemu-5.1.0 +QEMU_SIG_FILE=${QEMU_NAME}.tar.xz.sig +QEMU_TAR_FILE=${QEMU_NAME}.tar.xz + +# Clean previous build +rm -rf ${QEMU_NAME} ${QEMU_SIG_FILE} ${QEMU_TAR_FILE} + +# Get number of cores for build +if [ -n "${TVM_CI_NUM_CORES}" ]; then + num_cores=${TVM_CI_NUM_CORES} +else + num_cores=2 +fi + +# Set target list for QEMU +if [ "$1" == "--target-list" ]; then + shift + target_list=$1 +else + # Build these by defualt for microtvm reference virtual machine and ci_qemu. + target_list="aarch64-softmmu,arm-softmmu,i386-softmmu,riscv32-softmmu,riscv64-softmmu,x86_64-softmmu" +fi + sudo sed -i '/deb-src/s/^# //' /etc/apt/sources.list apt update apt-get -y build-dep qemu @@ -40,12 +69,13 @@ p5ez/+2k4VAIwIQoP5DoO06waLBffvLIAdPPKYsx71K67OoGG2svc7duC/+5qf1x -----END PGP ARMORED FILE----- EOF curl -OLs https://download.qemu.org/qemu-5.1.0.tar.xz -gpg --verify qemu-5.1.0.tar.xz.sig +gpg --verify ${QEMU_SIG_FILE} + +tar -xf ${QEMU_TAR_FILE} -tar -xf qemu-5.1.0.tar.xz -cd qemu-5.1.0 -./configure --target-list=aarch64-softmmu,arm-softmmu,i386-softmmu,riscv32-softmmu,riscv64-softmmu,x86_64-softmmu -make -j2 +cd ${QEMU_NAME} +./configure --target-list=${target_list} +make -j${num_cores} sudo make install # For debugging with qemu diff --git a/docs/api/python/graph_executor.rst b/docs/api/python/graph_executor.rst index 3f8811553ba4..1af93e88458d 100644 --- a/docs/api/python/graph_executor.rst +++ b/docs/api/python/graph_executor.rst @@ -16,6 +16,6 @@ under the License. tvm.contrib.graph_executor -------------------------- +-------------------------- .. automodule:: tvm.contrib.graph_executor :members: diff --git a/docs/conf.py b/docs/conf.py index 45f5da670608..1f645645f25d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -226,10 +226,10 @@ def git_describe_version(original_version): "introduction.py", "install.py", "tvmc_command_line_driver.py", - "auto_tuning_with_python.py", + "autotvm_relay_x86.py", "tensor_expr_get_started.py", - "autotvm_matmul.py", - "autoschedule_matmul.py", + "autotvm_matmul_x86.py", + "auto_scheduler_matmul_x86.py", "cross_compilation_and_rpc.py", "relay_quick_start.py", ], @@ -246,7 +246,7 @@ def git_describe_version(original_version): ], "language": [ "schedule_primitives.py", - "reduciton.py", + "reduction.py", "intrin_math.py", "scan.py", "extern_op.py", diff --git a/docs/contribute/code_guide.rst b/docs/contribute/code_guide.rst index 0ed2ce4ca9e1..725c3ce67b28 100644 --- a/docs/contribute/code_guide.rst +++ b/docs/contribute/code_guide.rst @@ -79,7 +79,7 @@ Python Code Styles ------------------ - The functions and classes are documented in `numpydoc `_ format. - Check your code style using ``make pylint`` -- Stick to language features as in ``python 3.5`` +- Stick to language features as in ``python 3.6`` Writing Python Tests diff --git a/docs/deploy/bnns.rst b/docs/deploy/bnns.rst index 7b62fb15a617..43c7e7bb264f 100644 --- a/docs/deploy/bnns.rst +++ b/docs/deploy/bnns.rst @@ -175,7 +175,8 @@ Operator support | nn.bias_add | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense | | | fusion | +------------------------+------------------------------------------------------------------------------+ -| add | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense fusion | +| add | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense | +| | fusion | +------------------------+------------------------------------------------------------------------------+ | nn.relu | Supported by BNNS integration only as a part of nn.conv2d or nn.dense fusion | +------------------------+------------------------------------------------------------------------------+ diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index 3cbbb10bd74b..b127de982b61 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -25,12 +25,20 @@ as well as how to integrate it with your project. .. image:: https://tvm.apache.org/images/release/tvm_flexible.png +Build the TVM runtime library +----------------------------- + +.. _build-tvm-runtime-on-target-device: + Unlike traditional deep learning frameworks. TVM stack is divided into two major components: -- TVM compiler, which does all the compilation and optimizations +- TVM compiler, which does all the compilation and optimizations of the model - TVM runtime, which runs on the target devices. -In order to integrate the compiled module, we **do not** need to build entire TVM on the target device. You only need to build the TVM compiler stack on your desktop and use that to cross-compile modules that are deployed on the target device. +In order to integrate the compiled module, we **do not** need to build entire +TVM on the target device. You only need to build the TVM compiler stack on your +desktop and use that to cross-compile modules that are deployed on the target device. + We only need to use a light-weight runtime API that can be integrated into various platforms. For example, you can run the following commands to build the runtime API @@ -46,11 +54,103 @@ on a Linux based embedded system such as Raspberry Pi: cmake .. make runtime -Note that we type `make runtime` to only build the runtime library. +Note that we type ``make runtime`` to only build the runtime library. + +It is also possible to cross compile the runtime. Cross compiling +the runtime library should not be confused with cross compiling models +for embedded devices. + If you want to include additional runtime such as OpenCL, -you can modify `config.cmake` to enable these options. +you can modify ``config.cmake`` to enable these options. After you get the TVM runtime library, you can link the compiled library +.. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/dev/tvm_deploy_crosscompile.svg + :align: center + :width: 85% + +A model (optimized or not by TVM) can be cross compiled by TVM for +different architectures such as ``aarch64`` on a ``x64_64`` host. Once the model +is cross compiled it is neccessary to have a runtime compatible with the target +architecture to be able to run the cross compiled model. + + +Cross compile the TVM runtime for other architectures +----------------------------------------------------- + +In the example :ref:`above ` the runtime library was +compiled on a Raspberry Pi. Producing the runtime library can be done much faster on +hosts that have high performace processors with ample resources (such as laptops, workstation) +compared to a target devices such as a Raspberry Pi. In-order to cross compile the runtime the toolchain +for the target device must be installed. After installing the correct toolchain, +the main difference compared to compiling natively is to pass some additional command +line argument to cmake that specify a toolchain to be used. For reference +building the TVM runtime library on a modern laptop (using 8 threads) for ``aarch64`` +takes around 20 seconds vs ~10 min to build the runtime on a Raspberry Pi 4. + +cross-compile for aarch64 +""""""""""""""""""""""""" + +.. code-block:: bash + + sudo apt-get update + sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu + +.. code-block:: bash + + cmake .. \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_VERSION=1 \ + -DCMAKE_C_COMPILER=/usr/bin/aarch64-linux-gnu-gcc \ + -DCMAKE_CXX_COMPILER=/usr/bin/aarch64-linux-gnu-g++ \ + -DCMAKE_FIND_ROOT_PATH=/usr/aarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DMACHINE_NAME=aarch64-linux-gnu + + make -j$(nproc) runtime + +For bare metal ARM devices the following toolchain is quite handy to install instead of gcc-aarch64-linux-* + +.. code-block:: bash + + sudo apt-get install gcc-multilib-arm-linux-gnueabihf g++-multilib-arm-linux-gnueabihf + + +cross-compile for RISC-V +""""""""""""""""""""""""" + +.. code-block:: bash + + sudo apt-get update + sudo apt-get install gcc-riscv64-linux-gnu g++-riscv64-linux-gnu + + +.. code-block:: bash + + cmake .. \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_VERSION=1 \ + -DCMAKE_C_COMPILER=/usr/bin/riscv64-linux-gnu-gcc \ + -DCMAKE_CXX_COMPILER=/usr/bin/riscv64-linux-gnu-g++ \ + -DCMAKE_FIND_ROOT_PATH=/usr/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DMACHINE_NAME=riscv64-linux-gnu + + make -j$(nproc) runtime + +The ``file`` command can be used to query the architecture of the produced runtime. + + +.. code-block:: bash + + file libtvm_runtime.so + libtvm_runtime.so: ELF 64-bit LSB shared object, UCB RISC-V, version 1 (GNU/Linux), dynamically linked, BuildID[sha1]=e9ak845b3d7f2c126dab53632aea8e012d89477e, not stripped + + +Optimize and tune models for target devices +------------------------------------------- + The easiest and recommended way to test, tune and benchmark TVM kernels on embedded devices is through TVM's RPC API. Here are the links to the related tutorials. @@ -58,8 +158,11 @@ Here are the links to the related tutorials. - :ref:`tutorial-cross-compilation-and-rpc` - :ref:`tutorial-deploy-model-on-rasp` +Deploy optimized model on target devices +---------------------------------------- + After you finished tuning and benchmarking, you might need to deploy the model on the -target device without relying on RPC. see the following resources on how to do so. +target device without relying on RPC. See the following resources on how to do so. .. toctree:: :maxdepth: 2 @@ -72,3 +175,5 @@ target device without relying on RPC. see the following resources on how to do s tensorrt vitis_ai bnns + + diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst index a39d9c8edea7..7950fcfbdbc9 100644 --- a/docs/deploy/tensorrt.rst +++ b/docs/deploy/tensorrt.rst @@ -166,6 +166,14 @@ There are some additional options which can be configured at runtime using envir model can use. It is generally best to use the highest value which does not cause you to run out of memory. You can use ``TVM_TENSORRT_MAX_WORKSPACE_SIZE`` to override this by specifying the workspace size in bytes you would like to use. +* For models which contain a dynamic batch dimension, the varaible ``TVM_TENSORRT_MULTI_ENGINE`` + can be used to determine how TensorRT engines will be created at runtime. The default mode, + ``TVM_TENSORRT_MULTI_ENGINE=0``, will maintain only one engine in memory at a time. If an input + is encountered with a higher batch size, the engine will be rebuilt with the new max_batch_size + setting. That engine will be compatible with all batch sizes from 1 to max_batch_size. This mode + reduces the amount of memory used at runtime. The second mode, ``TVM_TENSORRT_MULTI_ENGINE=1`` + will build a unique TensorRT engine which is optimized for each batch size that is encountered. + This will give greater performance, but will consume more memory. Operator support diff --git a/docs/dev/device_target_interactions.rst b/docs/dev/device_target_interactions.rst new file mode 100644 index 000000000000..e5fa708434fb --- /dev/null +++ b/docs/dev/device_target_interactions.rst @@ -0,0 +1,238 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tvm-target-specific-overview: + +Device/Target Interactions +========================== + +This documented is intended for developers interested in understanding +how the TVM framework interacts with specific device APIs, or who +may want to implement support for a new API or new hardware. + +There are three main aspects that must be implemented for any new +runtime environment. + +* The :ref:`DeviceAPI ` class gives a + handle to a specific device, and the API used to interact with it. + It defines a common interface for querying device parameters + (e.g. memory available, number of threads, etc.) and for performing + simple actions (e.g. copying memory from the host, or between + buffers on the device). + +* The :ref:`Target ` class contains a + description of the device on which a function will run. It is + exposed both to the target code generators and to the optimization + passes. + +* The :ref:`target code generators ` + construct a :ref:`Module ` consisting of + one or more :ref:`PackedFunc `, from + an IRModule. + +.. _tvm-target-specific-device-api: + +DeviceAPI +--------- + +The ``DeviceAPI`` represents a handle to a specific hardware device +API. (e.g. ``CUDADeviceAPI`` handles all interactions through the +CUDA framework.) Most ``DeviceAPI`` methods accept a ``device_id`` +parameter to specify which device should be accessed. In Python, +these are typically accessed using the :py:func:`tvm.runtime.device` +function, which returns a handle to a specific device, accessed +through a specific API. (e.g. ``tvm.runtime.device('cuda',0)`` gives +access to physical device ``0``, accessed through the CUDA API.) + +.. _device_api.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/device_api.h + +* Attribute queries - ``GetAttr`` allows different + device-specific parameters to be queried, such as the device name, + number of threads, etc. The parameters that can be queried are + defined in ``enum DeviceAttrKind`` in `device_api.h`_. Not all + query-able parameters are supported by all devices. If a parameter + cannot be queried (e.g. ``kMaxClockRate`` on Vulkan), or if a + parameter isn't applicable (e.g. ``kWarpSize`` on CPU), then those + queries should return ``nullptr``. + +* Setting active device - ``SetDevice`` should set a + particular device as being active. If a ``PackedFunc`` generated by + the target-specific code gen requires execution on a device, it + should run on the active device. + +* Memory management - Utilities for allocating and deallocating memory + on the device. + + * Allocate data space - ``AllocDataSpace`` and ``FreeDataSpace`` + allocate and free space on the device. These allocations can be + provided as inputs and outputs to an operator and make up the + primary data flow of the operator graph. It must be possible to + transfer data from the host to/from a data space. The return + value is an opaque ``void*``. While some implementations return a + memory address, this is not required, and the ``void*`` may be an + opaque handle that is interpretable only by the device backend + that generated it. The ``void*`` is used as an argument to other + backend-specific functions, such as ``CopyDataFromTo``. + + * Allocate work space - ``AllocWorkspace`` and ``FreeWorkspace`` + allocate and free space on the device. Unlike data space, these + are used for storage of intermediate values within an operator + definition, and are not required to be transferable to/from the + host device. If a ``DeviceAPI`` subclass does not implement these + methods, they will default to calling the corresponding + ``DataSpace`` functions. + + * Copy data - ``CopyDataFromTo`` should copy data from one location + to another. The type of copy is determined by the ``dev_from`` + and ``dev_to`` parameters. Implementations should support copying + memory from CPU to device, from device to CPU, and from one buffer + to another on a single device. If the source or destination + locations are on the CPU, the corresponding ``void*`` points to a + CPU address that can be passed into ``memcpy``. If the source or + destinations locations are on the device, the corresponding + ``void*`` was previously generated by either ``AllocDataSpace`` or + ``AllocWorkspace``. + + These copies are queued to execute on a specific + ``TVMStreamHandle``. However, implementations should not assume + that CPU buffers remains valid or accessible after the call to + ``CopyDataFromTo`` completes. + + +* Execution stream management - Utilities for handling + ``TVMStreamHandle``, which represents parallel streams of execution + used to execute commands. + + * Create stream - ``CreateStream`` and ``FreeStream`` should + allocate/free a handle to a stream of execution. If a device + implements only a single queue of commands, then ``CreateStream`` + should return ``nullptr``. + + * Set active stream - ``SetStream`` should set a stream as being + active. While active, if a ``PackedFunc`` generated by the + target-specific code gen requires execution on a device, the work + should be submitted to the active stream. + + * Synchronize to CPU - ``StreamSync`` should synchronize a stream of + execution to the CPU. The call to ``StreamSync`` should return + once all memory transfers and computations submitted prior to the + ``StreamSync`` call have completed. + + * Synchronize between streams - ``SyncStreamFromTo`` should + introduce a synchronization barrier between the source and + destination stream. That is, the destination stream may not + proceed beyond commands currently queued until the source stream + has completed all commands that are currently queued. + + +In order to be usable by the TVM framework, the new DeviceAPI should +then be registered with the following steps. + +#. Create a function that instantiates the new DeviceAPI, and returns + a pointer to it:: + + FooDeviceAPI* FooDeviceAPI::Global() { + static FooDeviceAPI inst; + return &inst; + } + +#. Register the function to the tvm registry:: + + TVM_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); + +.. _c_runtime_api.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h + +#. Add an entry for the new DeviceAPI to the ``TVMDeviceExtType`` enum + in `c_runtime_api.h`_. The value should be an unused value greater + than ``DLDeviceType::kDLExtDev``, but less than + ``DeviceAPIManager::kMaxDeviceAPI``. + +#. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the + enum value to a string representation. This string representation + should match the name given to ``TVM_REGISTER_GLOBAL``. + +#. Add entries to the ``MASK2STR`` and ``STR2MASK`` dictionaries of + :py:class:`tvm.runtime.Device` for the new enum value. + + +.. _tvm-target-specific-target: + +Target Definition +----------------- + +The ``Target`` object is a lookup table of properties about a physical +device, its hardware/driver limits, and its capabilities. The +``Target`` is accessible both during optimization and code generation +stages. While the same ``Target`` class is used for all runtime +targets, each runtime target may need to add target-specific options. + +.. _target_kind.cc: https://github.com/apache/tvm/blob/main/src/target/target_kind.cc + +In `target_kind.cc`_, add a new declaration of +``TVM_REGISTER_TARGET_KIND``, passing a string name of the new target, +and the ``TVMDeviceExtType`` or ``DLDeviceType`` enum value for the +device on which that target should run. Typically, the target name +and the device name will match. (e.g. The ``"cuda"`` target runs on +the ``kDLCUDA`` device.) There are exceptions, such as when multiple +different code generation targets can run on the same physical device. +(e.g. The ``"llvm"`` and ``"c"`` targets both run on the ``kDLCPU`` +device type.) + +All options for a specific target kind are added with the +``add_attr_option`` function, with optional default values. A +preprocessor can be added with ``set_attrs_preprocessor`` to define +any parameters that are dynamically based on other parameters or +queried from device properties. + +This argument definition defines a parser that can unpack a string +description of a target. This is done in the ``Target::Target(const +String&)`` constructor in C++, which accepts a JSON-formatted string +and is typically called using the :py:class:`tvm.target.Target` python +object. For example, ``tvm.target.Target('{"kind": "cuda", +"max_num_threads": 1024}')`` will create a ``cuda`` target, while +overriding the default maximum number of threads. + +In a code generator, the target properties can be accessed using +``target->GetAttr(param_name)`` in C++, or with the +``target.attrs`` dictionary in Python. + + +.. _tvm-target-specific-codegen: + +Target Code Generators +---------------------- + +The code generators take an optimized ``IRModule`` and converts it +into an executable representation. Each code generator must be +registered in order to be used by the TVM framework. This is done by +registering a function named ``"target.build.foo"``, where ``foo`` is +the same name as was used in the ``TVM_REGISTER_TARGET_KIND`` +definition above. :: + + tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target); + TVM_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); + +The code generator takes two arguments. The first is the ``IRModule`` +to compile, and the second is the ``Target`` that describes the device +on which the code should run. Because the environment performing the +compilation is not necessarily the same as the environment that will +be executing the code, code generators should not perform any +attribute lookups on the device itself, and should instead access +parameters stored in the ``Target``. + +Each function in the input ``IRModule`` should be accessible by name +in the output ``runtime::Module``. diff --git a/docs/dev/index.rst b/docs/dev/index.rst index 7eeecc12b33c..873af9c6a3b7 100644 --- a/docs/dev/index.rst +++ b/docs/dev/index.rst @@ -24,9 +24,11 @@ This page is organized as follows: - The `Example Compilation Flow`_ gives an overview of the steps that TVM takes to turn a high level description of a model into a deployable module. To get started, please read this section first. + - The `Logical Architecture Components`_ section describes the logical components. The sections after are specific guides focused on each logical component, organized by the component's name. + - Feel free to also check out the :ref:`dev-how-to` for useful development tips. This guide provides a few complementary views of the architecture. @@ -245,6 +247,7 @@ for learning-based optimizations. debugger virtual_machine introduction_to_module_serialization + device_target_interactions tvm/node -------- @@ -312,6 +315,11 @@ It also provides a common `Target` class that describes the target. The compilation pipeline can be customized according to the target by querying the attribute information in the target and builtin information registered to each target id(cuda, opencl). +.. toctree:: + :maxdepth: 1 + + device_target_interactions + tvm/tir ------- diff --git a/docs/dev/runtime.rst b/docs/dev/runtime.rst index fc03ed806bac..dfda00c1d6c4 100644 --- a/docs/dev/runtime.rst +++ b/docs/dev/runtime.rst @@ -42,8 +42,11 @@ We also want the runtime core to be minimal to deploy to embedded devices. PackedFunc ---------- -`PackedFunc`_ is a simple but elegant solution -we find to solve the challenges listed. The following code block provides an example in C++ +`PackedFunc`_ is a simple but elegant solution we find to solve the +challenges listed. A single ``PackedFunc`` object represents a +function call whose caller and callee may be in different languages. + +The following code block provides an example in C++ .. _PackedFunc: https://github.com/apache/tvm/blob/main/include/tvm/runtime/packed_func.h @@ -147,6 +150,8 @@ The overhead of calling into PackedFunc vs. a normal function is small, as it is So it is OK as long as we don't wrap small functions. In summary, the PackedFunc is the universal glue in TVM where we use it extensively to support our compiler and deployment. +.. _tvm-runtime-system-module: + Module ------ @@ -291,3 +296,13 @@ To support extension types, we used a registry system to register type related i in C++, see `Extension types`_ for more details. .. _Extension types: https://github.com/apache/tvm/tree/main/apps/extension + + +Runtime-Specific Information +============================ + +.. toctree:: + :maxdepth: 1 + :glob: + + runtimes/* diff --git a/docs/dev/runtimes/vulkan.rst b/docs/dev/runtimes/vulkan.rst new file mode 100644 index 000000000000..ed0dbe33a305 --- /dev/null +++ b/docs/dev/runtimes/vulkan.rst @@ -0,0 +1,207 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tvm-runtime-vulkan: + +Vulkan Runtime +============== + +TVM supports using Vulkan compute shaders to execute queries. Each +computational kernel is compiled into a SPIR-V shader, which can then +be called using the TVM interface. + +.. _tvm-runtime-vulkan-features: + +Vulkan Features, Limits +----------------------- + +.. _Required Limits: https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/vkspec.html#limits-minmax + +Since different Vulkan implementations may enable different optional +features or have different physical limits, the code generation must +know which features are available to use. These correspond to +specific Vulkan capabilities/limits as in +:ref:`Vulkan Capabilities Table `. +If unspecified, TVM assumes that a capability is not available, or +that a limit is the minimum guaranteed by the Vulkan spec in the +`Required Limits`_ section. + +These parameters can be either explicitly specific when defining a +:ref:`Target `, or can be queried from a +device. To query from a device, the special parameter +``-from_device=N`` can be used to query all vulkan device parameters +from device id ``N``. Any additional parameters explicitly specified +will override the parameters queried from the device. + +.. _VkSubgroupFeatureFlagBits: https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkSubgroupFeatureFlagBits.html + +.. list-table:: Vulkan Capabilities + :name: tvm-runtime-table-vulkan-capabilities + :header-rows: 1 + + * - Target Parameter + - Required Vulkan Version/Extension + - Parameter Queried + - Default Value + + * - ``supported_subgroup_operations`` + - Vulkan 1.1+ + - ``VkPhysicalDeviceSubgroupProperties::supportedOperations`` + - 0 (interpreted as `VkSubgroupFeatureFlagBits`_) + + * - ``max_push_constants_size`` + - + - ``VkPhysicalDeviceLimits::maxPushConstantsSize`` + - 128 bytes + + * - ``max_uniform_buffer_range`` + - + - ``VkPhysicalDeviceLimits::maxUniformBufferRange`` + - 16384 bytes + + + * - ``max_storage_buffer_range`` + - + - ``VkPhysicalDeviceLimits::maxStorageBufferRange`` + - 2\ :sup:`27`\ bytes + + + * - ``max_per_stage_descriptor_storage_buffer`` + - + - ``VkPhysicalDeviceLimits::maxPerStageDescriptorStorageBuffers`` + - 4 + + + * - ``supports_storage_buffer_storage_class`` + - VK_KHR_storage_buffer_storage_class + - + - false + + + * - ``supports_storage_buffer_8bit_access`` + - VK_KHR_8bit_storage + - ``VkPhysicalDevice8BitStorageFeaturesKHR::storageBuffer8BitAccess`` + - false + + + * - ``supports_storage_buffer_16bit_access`` + - VK_KHR_16bit_storage + - ``VkPhysicalDevice16BitStorageFeaturesKHR::storageBuffer16BitAccess`` + - false + + + * - ``supports_float16`` + - VK_KHR_shader_float16_int8 + - ``VkPhysicalDeviceShaderFloat16Int8FeaturesKHR::shaderFloat16`` + - false + + + * - ``supports_float64`` + - + - ``VkPhysicalDeviceFeatures::shaderFloat64`` + - false + + + * - ``supports_int8`` + - VK_KHR_shader_float16_int8 + - ``VkPhysicalDeviceShaderFloat16Int8FeaturesKHR::shaderInt8`` + - false + + + * - ``supports_int16`` + - + - ``VkPhysicalDeviceFeatures::shaderInt16`` + - false + + + * - ``supports_int64`` + - + - ``VkPhysicalDeviceFeatures::shaderInt64`` + - false + + + +As of May 2021, not all Vulkan implementations are supported. For +example, support for 64-bit integers is required. If a Vulkan target +is not supported, an error message should be issued during SPIR-V code +generation. Efforts are also underway to remove these requirements +and support additional Vulkan implementations. + + +.. _tvm-runtime-vulkan-spirv-capabilities: + +SPIR-V Capabilities +------------------- + +Some of the device-specific capabilities also correspond to SPIR-V +capabilities or extensions that must be declared in the shader, or a +minimum SPIR-V version required in order to use a feature. The +TVM-generated shaders will declare the minimum set of +extensions/capabilities and the minimum allowed version of SPIR-V +that are needed to execute the compiled graph. + +If the shader generation requires a capability or extension that is +not enabled in the ``Target``, an exception will be raised. + + +.. list-table:: Vulkan Capabilities + :name: tvm-table-vulkan-capabilities + :header-rows: 1 + + * - Target Parameter + - Required SPIR-V Version/Extension + - Declared Capability + + * - ``supported_subgroup_operations`` + - SPIR-V 1.3+ + - Varies, see `VkSubgroupFeatureFlagBits`_ + + * - ``supports_storage_buffer_storage_class`` + - SPV_KHR_storage_buffer_storage_class + - + + * - ``supports_storage_buffer_8bit_access`` + - SPV_KHR_8bit_storage + - StorageBuffer8BitAccess + + * - ``supports_storage_buffer_16bit_access`` + - SPV_KHR_16bit_storage + - StorageBuffer16BitAccess + + * - ``supports_float16`` + - + - Float16 + + + * - ``supports_float64`` + - + - Float64 + + + * - ``supports_int8`` + - + - Int8 + + + * - ``supports_int16`` + - + - Int16 + + + * - ``supports_int64`` + - + - Int64 diff --git a/docs/index.rst b/docs/index.rst index 2a1078e645ab..a7ae68c87b01 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,7 @@ For Developers dev/how_to microtvm/index errors + faq .. toctree:: :maxdepth: 1 @@ -77,7 +78,6 @@ For Developers :caption: MISC vta/index - faq Index diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index bc6cdb90da15..5d723d1ce048 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -51,27 +51,31 @@ Build the Shared Library Our goal is to build the shared libraries: -- On Linux the target library are `libtvm.so` -- On macOS the target library are `libtvm.dylib` -- On Windows the target library are `libtvm.dll` + - On Linux the target library are `libtvm.so` and `libtvm_runtime.so` + - On macOS the target library are `libtvm.dylib` and `libtvm_runtime.dylib` + - On Windows the target library are `libtvm.dll` and `libtvm_runtime.dll` +It is also possible to :ref:`build the runtime ` library only. + +The minimal building requirements for the ``TVM`` libraries are: + + - A recent c++ compiler supporting C++ 14 (g++-5 or higher) + - CMake 3.5 or higher + - We highly recommend to build with LLVM to enable all the features. + - If you want to use CUDA, CUDA toolkit version >= 8.0 is required. If you are upgrading from an older version, make sure you purge the older version and reboot after installation. + - On macOS, you may want to install `Homebrew ` to easily install and manage dependencies. + +To install the these minimal pre-requisites on Ubuntu/Debian like +linux operating systems, execute (in a terminal): .. code:: bash sudo apt-get update sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev -The minimal building requirements are - -- A recent c++ compiler supporting C++ 14 (g++-5 or higher) -- CMake 3.5 or higher -- We highly recommend to build with LLVM to enable all the features. -- If you want to use CUDA, CUDA toolkit version >= 8.0 is required. If you are upgrading from an older version, make sure you purge the older version and reboot after installation. -- On macOS, you may want to install `Homebrew ` to easily install and manage dependencies. - We use cmake to build the library. -The configuration of TVM can be modified by `config.cmake`. +The configuration of TVM can be modified by editing `config.cmake` and/or by passing cmake flags to the command line: - First, check the cmake in your system. If you do not have cmake, diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 257fe085bfe5..49d3a42d3e98 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -80,6 +80,17 @@ Here is another example to match an op with a specific attribute: y = relay.var('y') assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) +Or a convolution with a specific kernel size: + +.. code-block:: python + + def test_match_kernel_size(): + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + x = relay.var('x') + y = relay.var('y') + assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3])) + + Matching an Optional Op *********************** diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index e6d8e74ae0f9..c2add6a36734 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -68,3 +68,4 @@ // Uncomment the following lines to enable OpenCL // #include "../../src/runtime/opencl/opencl_device_api.cc" // #include "../../src/runtime/opencl/opencl_module.cc" +// #include "../src/runtime/source_utils.cc" diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index f8e63ed5857a..cf84b9a3a641 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -25,7 +25,6 @@ #include #include -#include #include #include diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index 3f1096b10a8b..5e1165d509c4 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -25,7 +25,6 @@ #define TVM_ARITH_PATTERN_H_ #include -#include #include namespace tvm { diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index caff37cbf6d2..0ca14c43eb47 100755 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -50,7 +50,6 @@ #include #include -#include #include #include diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 71a69a000944..418d532fdd5f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -42,17 +43,68 @@ #include namespace tvm { + +/*! + * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) + * \param mod The IRmodule to lower + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); + +/*! + * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list + * defined in CreatePassList) + * \param func The PrimFunc to lower + * \param name The name of the lowered function. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, + bool simple_mode = false); + /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds); +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. + * \param args The arguments to the function (Array of Tensor, Buffer and Vars) + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want + * to apply lowering passes as well, use LowerSchedule. + * \param sch The schedule + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds); /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 231c04e69821..50e9bcbab273 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -30,7 +30,9 @@ #include #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 2295baa0297b..b910d32ceca4 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 5b9e0714e202..c1a012f05318 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,7 +26,9 @@ #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h new file mode 100644 index 000000000000..1b9eb9c1b7c8 --- /dev/null +++ b/include/tvm/ir/instrument.h @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/instrument.h + * + * This file introduces a pass instrument infrastructure, inspired by LLVM and MLIR. + * It inserts instrumentation points around passes. + */ +#ifndef TVM_IR_INSTRUMENT_H_ +#define TVM_IR_INSTRUMENT_H_ + +#include +#include + +#include +#include + +namespace tvm { + +class IRModule; + +// Forward class for PassInstrumentNode methods +namespace transform { +class PassInfo; +} // namespace transform + +namespace instrument { + +/*! + * \brief PassInstrumentNode forms an instrument implementation. + * It provides API for users to register callbacks at different instrumentation points. + * + * Within a PassContext, call sequence of a PassInstrument implementation is like: + * + * with PassContext(instruments=[pi]): # pi = a PassInstrument implementation + * pi.EnterPassContext() + * + * if pi.ShouldRun(Pass1): + * pi.RunBeforePass() + * Pass1() + * pi.RunAfterPass() + * + * if pi.ShouldRun(Pass2): + * pi.RunBeforePass() + * Pass2() + * pi.RunAfterPass() + * + * pi.ExitPassContext() + * + * `EnterPassContext` and `ExitPassContext` are only called once when entering/exiting a + * PassContext. `ShouldRun`, `RunBeforePass` and `RunAfterPass` are called multiple times depending + * on how many passes. + * + * If there are multiple pass instrumentations provided, the instrument points are the same. + * PassInstrument implementations' callbacks are called in order: + * + * with PassContext(instruments=[pi1, pi2]): # pi1, pi2 = two distinct PassInstrument impls + * pi.EnterPassContext() for pi in instruments + * + * should_run = all([pi.ShoudRun(Pass1) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass1() + * pi.RunAfterPass() for pi in instruments + * + * should_run = all([pi.ShouldRun(Pass2) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass2() + * pi.RunAfterPass() for pi in instruments + * + * pi.ExitPassContext() for pi in instruments + * + * Note: + * 1. Assume there is no dependency between PassInstrument implementations in `instruments` . + * 2. `EnterPassContext` and `ExitPassContext` have `with` behavior (see PassContext and its FFI): + * If there is any exception raised in `ShouldRun()`, `RunBeforePass()`, `RunAfterPass()` and + * `Pass()`, `ExitPassContext()` is still called. + * 3. In mutiple PassInstrument instances scenario, callbacks are called in order: + * If one throws exceptions, remainings will not be called. + * + * \sa PassInstrument + * \sa src/ir/transform.cc + */ +class PassInstrumentNode : public Object { + public: + /*! \brief Name of this pass instrument object. */ + String name; + + virtual ~PassInstrumentNode() {} + + /*! \brief Instrument when entering PassContext. Called once within a PassContext. */ + virtual void EnterPassContext() const = 0; + + /*! \brief Instrument when exiting PassContext. Called once within a PassContext. */ + virtual void ExitPassContext() const = 0; + + /*! + * \brief Determine whether to run the pass or not. Called multiple times depend on number of + * passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; + + /*! + * \brief Instrument before pass run. Called multiple times depend on number of passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + /*! + * \brief Instrument after pass run. Called multiple time depend on number of passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); +}; + +/*! + * \brief Managed reference class for PassInstrumentNode + * \sa PassInstrumentNode + */ +class PassInstrument : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); +}; + +} // namespace instrument +} // namespace tvm + +#endif // TVM_IR_INSTRUMENT_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 07d582a298e4..638f132e3179 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -29,7 +29,9 @@ #include #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 9456ea80d860..a18d42902503 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -244,12 +244,18 @@ class OpRegEntry { runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> type_rel_func); /*! - * \brief Set the the attrs type key and index to be AttrsType. + * \brief Set the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. * \return reference to self. */ template inline OpRegEntry& set_attrs_type(); + /*! + * \brief Set the attrs type key and index to be AttrsType. + * \param key The attribute type key to be set. + * \return reference to self. + */ + inline OpRegEntry& set_attrs_type_key(const String& key); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -454,6 +460,12 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) return *this; } +inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*) + get()->attrs_type_key = key; + get()->attrs_type_index = Object::TypeKey2Index(key); + return *this; +} + inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*) get()->support_level = n; return *this; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 50c6f8dd8c3a..cb556fc13de7 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -58,8 +58,10 @@ #include #include +#include #include -#include +#include +#include #include #include @@ -68,15 +70,6 @@ namespace tvm { namespace transform { -// Forward declare for TraceFunc. -class PassInfo; - -/*! - * \brief A callback for tracing passes, useful for debugging and logging. - */ -using TraceFunc = - runtime::TypedPackedFunc; - /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -95,8 +88,9 @@ class PassContextNode : public Object { mutable Optional diag_ctx; /*! \brief Pass specific configurations. */ Map config; - /*! \brief Trace function to be invoked before and after each pass. */ - TraceFunc trace_func; + + /*! \brief A list of pass instrument implementations. */ + Array instruments; PassContextNode() = default; @@ -134,6 +128,7 @@ class PassContextNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); } @@ -189,12 +184,46 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Apply the tracing functions of the context to the module, with the info. - * \param module The IRModule to trace. + * \brief Get all supported configuration names and metadata, registered within the PassContext. + * \return Map indexed by the config name, pointing to the metadata map as key-value + */ + TVM_DLL static Map> ListConfigs(); + + /*! + * \brief Call instrument implementations' callbacks when entering PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + */ + TVM_DLL void InstrumentEnterPassContext(); + + /*! + * \brief Call instrument implementations' callbacks when exiting PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + */ + TVM_DLL void InstrumentExitPassContext(); + + /*! + * \brief Call instrument implementations' callbacks before a pass run. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return false: the pass is skipped; true: the pass runs. + */ + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + + /*! + * \brief Call instrument implementations callbacks after a pass run. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + * + * \param mod The module that an optimization pass runs on. * \param info The pass information. - * \param is_before Indicated whether the tracing is before or after a pass. */ - TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; /*! * \brief Check whether a pass is enabled. @@ -275,7 +304,7 @@ class PassInfoNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; -/* +/*! * \brief Managed reference class for PassInfoNode * \sa PassInfoNode */ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 4e4e009b2875..c772650809fa 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -51,7 +51,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 6acd2e7dbdd8..c4b54ef0f27d 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -23,7 +23,7 @@ #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ #define TVM_NODE_ATTR_REGISTRY_MAP_H_ -#include +#include #include #include diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 7b2a9f8061b4..ad4fb1e1c27a 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -39,7 +39,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index d5309bca894d..6c25c3d2d21d 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -24,7 +24,7 @@ #define TVM_NODE_STRUCTURAL_EQUAL_H_ #include -#include +#include #include #include diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index a661a852780d..887a012cfc93 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -24,7 +24,6 @@ #define TVM_NODE_STRUCTURAL_HASH_H_ #include -#include #include #include diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 15f6b03f0c06..a58bb8750c14 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -29,8 +29,6 @@ #include -#include "tvm/runtime/container.h" - namespace tvm { namespace relay { diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cc97a94a1406..69a9c64a4588 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode { struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; + Optional index_rank; TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); + TVM_ATTR_FIELD(index_rank) + .set_default(NullValue()) + .describe( + "The size of an indexing tuple, which is a fixed value. Only needed when the number of " + "indexting tuples is dynamic."); } }; + struct TakeAttrs : public tvm::AttrsNode { Integer batch_dims; Integer axis; @@ -303,6 +310,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Optional> end; Optional> strides; std::string slice_mode; + Optional> axes; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); @@ -317,6 +325,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode { "size - The input strides will be ignored, input end in this mode indicates the size" "of a slice starting at the location specified by begin. If end[i] is -1," "all remaining elements in that dimension are included in the slice"); + TVM_ATTR_FIELD(axes).describe( + "Axes along which slicing is applied. When it is specified, the length of begin, end, " + "strides, and axes must be equal."); } }; diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 005b900d5d44..976304e79c34 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -114,11 +114,19 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + std::string output_format; + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, - "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} + "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(output_format) + .set_default("onnx") + .describe( + "Output format, onnx or tensorflow. Returns outputs in a way that can be easily " + "consumed by each frontend."); + } }; /*! \brief Attributes used in roi_align operators */ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 17718d14da00..daad8514f9ff 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -227,6 +227,11 @@ class Var : public Expr { class Call; /*! \brief Call container. */ class CallNode : public ExprNode { + protected: + // CallNode uses own deleter to indirectly call non-recursive destructor + Object::FDeleter saved_deleter_; + static void Deleter_(Object* ptr); + public: /*! * \brief The operator(function) being invoked @@ -290,6 +295,7 @@ class CallNode : public ExprNode { static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); + friend class Call; }; class Call : public Expr { diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 4a5de33af4b9..751593f94cc0 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -26,7 +26,6 @@ #include #include -#include #include #include diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index e3fd5ae77193..93a56cede77b 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 123b7e395faa..b090e3e40063 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h deleted file mode 100644 index edceabc3525a..000000000000 --- a/include/tvm/runtime/container.h +++ /dev/null @@ -1,3124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container.h - * \brief Common POD(plain old data) container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_H_ -#define TVM_RUNTIME_CONTAINER_H_ - -#ifndef USE_FALLBACK_STL_MAP -#define USE_FALLBACK_STL_MAP 0 -#endif - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -// We use c++14 std::experimental::string_view for optimizing hash computation -// only right now, its usage is limited in this file. Any broader usage of -// std::experiment in our core codebase is discouraged and needs community -// discussion for each use case. Reference for feature test macros of -// string_view: -// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations -// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros -#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 -#define TVM_USE_CXX14_STRING_VIEW_HASH 1 -#else -#define TVM_USE_CXX14_STRING_VIEW_HASH 0 -#endif - -// Tested with clang version 9.0.1 and c++17. It will detect string_view support -// correctly. -#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606 -#define TVM_USE_CXX17_STRING_VIEW_HASH 1 -#else -#define TVM_USE_CXX17_STRING_VIEW_HASH 0 -#endif - -#if TVM_USE_CXX17_STRING_VIEW_HASH -#include -#elif TVM_USE_CXX14_STRING_VIEW_HASH -#include -#endif - -#include -#include -#include - -namespace llvm { -// String to llvm object compatibility. -class StringRef; -} // namespace llvm - -namespace tvm { -namespace runtime { - -// Forward declare TVMArgValue -class TVMArgValue; - -/*! \brief String-aware ObjectRef equal functor */ -struct ObjectHash { - /*! - * \brief Calculate the hash code of an ObjectRef - * \param a The given ObjectRef - * \return Hash code of a, string hash for strings and pointer address otherwise. - */ - size_t operator()(const ObjectRef& a) const; -}; - -/*! \brief String-aware ObjectRef hash functor */ -struct ObjectEqual { - /*! - * \brief Check if the two ObjectRef are equal - * \param a One ObjectRef - * \param b The other ObjectRef - * \return String equality if both are strings, pointer address equality otherwise. - */ - bool operator()(const ObjectRef& a, const ObjectRef& b) const; -}; - -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. - * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. - * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { - public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType& operator[](size_t idx) const { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType& operator[](size_t idx) { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && std::is_trivial::value)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType* fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - - protected: - /*! - * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - void* field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType* Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. - * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void* AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType* self = Self(); - char* data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter& operator++() { - ++iter_; - return *this; - } - IterAdapter& operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class ReverseIterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter& operator++() { - --iter_; - return *this; - } - ReverseIterAdapter& operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter& operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter& operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const ReverseIterAdapter& rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! \brief array node content in array */ -class ArrayNode : public Object, public InplaceArrayBase { - public: - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const ObjectRef at(int64_t i) const { return this->operator[](i); } - - /*! \return begin constant iterator */ - const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } - - /*! \return end constant iterator */ - const ObjectRef* end() const { return begin() + size_; } - - /*! \brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! - * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { - ObjectPtr p = ArrayNode::Empty(n); - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < n; ++i) { - new (itr++) ObjectRef(val); - } - return p; - } - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); - - private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - ObjectRef* MutableBegin() const { - return static_cast(InplaceArrayBase::AddressOf(0)); - } - - /*! \return end mutable iterator */ - ObjectRef* MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Create an ArrayNode with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ICHECK_GE(n, 0); - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - return p; - } - - /*! - * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { - ObjectRef* itr = MutableBegin() + idx; - for (; first != last; ++first) { - ObjectRef ref = *first; - new (itr++) ObjectRef(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_begin; - ObjectRef* to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_end; - ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) ObjectRef(val); - ++size_; - } - return this; - } - - /*! - * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayNode* ShrinkBy(int64_t delta) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->ObjectRef::~ObjectRef(); - --size_; - } - return this; - } - - /*! \brief Number of elements used */ - int64_t size_; - - /*! \brief Number of elements allocated */ - int64_t capacity_; - - /*! \brief Initial size of ArrayNode */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! - * \brief Array, container representing a contigious sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. - * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. - * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content ObjectRef type. - */ -template ::value>::type> -class Array : public ObjectRef { - public: - using value_type = T; - // constructors - /*! - * \brief default constructor - */ - Array() { data_ = ArrayNode::Empty(); } - - /*! - * \brief move constructor - * \param other source - */ - Array(Array&& other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array& other) : ObjectRef() { // NOLINT(*) - data_ = other.data_; - } - - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } - - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - - public: - // iterators - struct ValueConverter { - using ResultType = T; - static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } - }; - - using iterator = IterAdapter; - using reverse_iterator = ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayNode()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayNode()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayNode::end() is never nullptr - return reverse_iterator(GetArrayNode()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayNode::begin() is never nullptr - return reverse_iterator(GetArrayNode()->begin() - 1); - } - - public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - return DowncastNoCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->size_; - } - - /*! \return The capacity of the array */ - size_t capacity() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->end() - 1)); - } - - public: - // mutation in std::vector, implements copy-on-write - - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T& item) { - ArrayNode* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T& val) { - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) ObjectRef(val); - } - - /*! - * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; - int64_t size = GetArrayNode()->size_; - ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st - << ", because Array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! - * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t size = GetArrayNode()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; - ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) - << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayNode()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayNode()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayNode* p = CopyOnWrite(); - p->clear(); - } - } - - public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayNode* p = this->CopyOnWrite(); - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayNode */ - ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template - void MutateByApply(F fmutate) { - if (data_ == nullptr) { - return; - } - struct StackFrame { - ArrayNode* p; - ObjectRef* itr; - int64_t i; - int64_t size; - }; - std::unique_ptr s = std::make_unique(); - s->p = GetArrayNode(); - s->itr = s->p->MutableBegin(); - s->i = 0; - s->size = s->p->size_; - if (!data_.unique()) { - // Loop invariant: keeps iterating when - // 1) data is not unique - // 2) no elements are actually mutated yet - for (; s->i < s->size; ++s->i, ++s->itr) { - T new_elem = fmutate(DowncastNoCheck(*s->itr)); - // do nothing when there is no mutation - if (new_elem.same_as(*s->itr)) { - continue; - } - // loop invariant breaks when the first real mutation happens - // we copy the elements into a new unique array - ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); - s->itr = copy->MutableBegin() + (s->i++); - *s->itr++ = std::move(new_elem); - data_ = std::move(copy); - // make sure `data_` is unique and break - break; - } - } - // when execution comes to this line, it is guaranteed that either - // 1) i == size - // or 2) data_.unique() is true - for (; s->i < s->size; ++s->i, ++s->itr) { - *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); - } - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { - int64_t cap = std::distance(first, last); - ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; - ArrayNode* p = GetArrayNode(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayNode::Empty(cap); - p = GetArrayNode(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) ObjectRef(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayNode* CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayNode::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayNode; - - private: - /*! - * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. - * \param reserve_extra Number of extra slots needed - * \return ArrayNode pointer to the unique copy - */ - ArrayNode* CopyOnWrite(int64_t reserve_extra) { - ArrayNode* p = GetArrayNode(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayNode::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayNode::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayNode to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayNode* SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayNode::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); - } else { - data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); - } - return static_cast(data_.get()); - } -}; - -/*! - * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. - * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template ::value>::type> -inline Array Concat(Array lhs, const Array& rhs) { - for (const auto& x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -// Specialize make_object to make sure it is correct. -template <> -inline ObjectPtr make_object() { - return ArrayNode::Empty(); -} - -/*! \brief An object representing a structure or enumeration. */ -class ADTObj : public Object, public InplaceArrayBase { - public: - /*! \brief The tag representing the constructor used. */ - int32_t tag; - /*! \brief Number of fields in the ADT object. */ - uint32_t size; - // The fields of the structure follows directly in memory. - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; - static constexpr const char* _type_key = "runtime.ADT"; - TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); - - private: - /*! - * \return The number of elements in the array. - */ - size_t GetSize() const { return size; } - - /*! - * \brief Initialize the elements in the array. - * - * \tparam Iterator Iterator type of the array. - * \param begin The begin iterator. - * \param end The end iterator. - */ - template - void Init(Iterator begin, Iterator end) { - size_t num_elems = std::distance(begin, end); - this->size = 0; - auto it = begin; - for (size_t i = 0; i < num_elems; ++i) { - InplaceArrayBase::EmplaceInit(i, *it++); - // Only increment size after the initialization succeeds - this->size++; - } - } - - friend class ADT; - friend InplaceArrayBase; -}; - -/*! \brief reference to algebraic data type objects. */ -class ADT : public ObjectRef { - public: - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param fields The fields of the ADT object. - * \return The constructed ADT object reference. - */ - ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; - - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param begin The begin iterator to the start of the fields array. - * \param end The end iterator to the end of the fields array. - * \return The constructed ADT object reference. - */ - template - ADT(int32_t tag, Iterator begin, Iterator end) { - size_t num_elems = std::distance(begin, end); - auto ptr = make_inplace_array_object(num_elems); - ptr->tag = tag; - ptr->Init(begin, end); - data_ = std::move(ptr); - } - - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param init The initializer list of fields. - * \return The constructed ADT object reference. - */ - ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; - - /*! - * \brief Access element at index. - * - * \param idx The array index - * \return const ObjectRef - */ - const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } - - /*! - * \brief Return the ADT tag. - */ - int32_t tag() const { return operator->()->tag; } - - /*! - * \brief Return the number of fields. - */ - size_t size() const { return operator->()->size; } - - /*! - * \brief Construct a tuple object. - * - * \tparam Args Type params of tuple feilds. - * \param args Tuple fields. - * \return ADT The tuple object reference. - */ - template - static ADT Tuple(Args&&... args) { - return ADT(0, std::forward(args)...); - } - - TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); -}; - -/*! \brief An object representing string. It's POD type. */ -class StringObj : public Object { - public: - /*! \brief The pointer to string data. */ - const char* data; - - /*! \brief The length of the string object. */ - uint64_t size; - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; - static constexpr const char* _type_key = "runtime.String"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); - - private: - /*! \brief String object which is moved from std::string container. */ - class FromStd; - - friend class String; -}; - -/*! - * \brief Reference to string objects. - * - * \code - * - * // Example to create runtime String reference object from std::string - * std::string s = "hello world"; - * - * // You can create the reference from existing std::string - * String ref{std::move(s)}; - * - * // You can rebind the reference to another string. - * ref = std::string{"hello world2"}; - * - * // You can use the reference as hash map key - * std::unordered_map m; - * m[ref] = 1; - * - * // You can compare the reference object with other string objects - * assert(ref == "hello world", true); - * - * // You can convert the reference to std::string again - * string s2 = (string)ref; - * - * \endcode - */ -class String : public ObjectRef { - public: - /*! - * \brief Construct an empty string. - */ - String() : String(std::string()) {} - /*! - * \brief Construct a new String object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - String(std::string other); // NOLINT(*) - - /*! - * \brief Construct a new String object - * - * \param other a char array. - */ - String(const char* other) // NOLINT(*) - : String(std::string(other)) {} - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - * - */ - inline String& operator=(std::string other); - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - */ - inline String& operator=(const char* other); - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const std::string& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const char* other) const { - return memncmp(data(), other, size(), std::strlen(other)); - } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const { return get()->data; } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { - const auto* ptr = get(); - return ptr->size; - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return get()->data; } - - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{get()->data, size()}; } - - // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h - /*! - * \brief Convert String to an llvm::StringRef object - * - * \return llvm::StringRef - */ - inline operator llvm::StringRef() const; - - /*! - * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String - * \param val The value to be checked - * \return A boolean indicating if val can be converted to String - */ - inline static bool CanConvertFrom(const TVMArgValue& val); - - /*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. - */ - static size_t HashBytes(const char* data, size_t size) { - // This function falls back to string copy with c++11 compiler and is - // recommended to be compiled with c++14 -#if TVM_USE_CXX17_STRING_VIEW_HASH - return std::hash()(std::string_view(data, size)); -#elif TVM_USE_CXX14_STRING_VIEW_HASH - return std::hash()(std::experimental::string_view(data, size)); -#else - return std::hash()(std::string(data, size)); -#endif - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); - - private: - /*! - * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); - - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - std::string ret(lhs, lhs_size); - ret.append(rhs, rhs_size); - return String(ret); - } - - // Overload + operator - friend String operator+(const String& lhs, const String& rhs); - friend String operator+(const String& lhs, const std::string& rhs); - friend String operator+(const std::string& lhs, const String& rhs); - friend String operator+(const String& lhs, const char* rhs); - friend String operator+(const char* lhs, const String& rhs); - - friend struct tvm::runtime::ObjectEqual; -}; - -/*! \brief An object representing string moved from std::string. */ -class StringObj::FromStd : public StringObj { - public: - /*! - * \brief Construct a new FromStd object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - explicit FromStd(std::string other) : data_container{other} {} - - private: - /*! \brief Container that holds the memory. */ - std::string data_container; - - friend class String; -}; - -inline String::String(std::string other) { - auto ptr = make_object(std::move(other)); - ptr->size = ptr->data_container.size(); - ptr->data = ptr->data_container.data(); - data_ = std::move(ptr); -} - -inline String& String::operator=(std::string other) { - String replace{std::move(other)}; - data_.swap(replace.data_); - return *this; -} - -inline String& String::operator=(const char* other) { return operator=(std::string(other)); } - -inline String operator+(const String& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const std::string& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const std::string& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char* lhs, const String& rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const char* rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } - -// Overload == operator -inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream& operator<<(std::ostream& out, const String& input) { - out.write(input.data(), input.size()); - return out; -} - -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } -} - -inline size_t ObjectHash::operator()(const ObjectRef& a) const { - if (const auto* str = a.as()) { - return String::HashBytes(str->data, str->size); - } - return ObjectPtrHash()(a); -} - -inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { - if (a.same_as(b)) { - return true; - } - if (const auto* str_a = a.as()) { - if (const auto* str_b = b.as()) { - return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; - } - } - return false; -} - -/*! \brief Helper to represent nullptr for optional. */ -struct NullOptType {}; - -/*! - * \brief Optional container that to represent to a Nullable variant of T. - * \tparam T The original ObjectRef. - * - * \code - * - * Optional opt0 = nullptr; - * Optional opt1 = String("xyz"); - * ICHECK(opt0 == nullptr); - * ICHECK(opt1 == "xyz"); - * - * \endcode - */ -template -class Optional : public ObjectRef { - public: - using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); - // default constructors. - Optional() = default; - Optional(const Optional&) = default; - Optional(Optional&&) = default; - Optional& operator=(const Optional&) = default; - Optional& operator=(Optional&&) = default; - /*! - * \brief Construct from an ObjectPtr - * whose type already matches the ContainerType. - * \param ptr - */ - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! \brief Nullopt handling */ - Optional(NullOptType) {} // NOLINT(*) - // nullptr handling. - // disallow implicit conversion as 0 can be implicitly converted to nullptr_t - explicit Optional(std::nullptr_t) {} - Optional& operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} - Optional& operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - // delete the int constructor - // since Optional(0) is ambiguious - // 0 can be implicitly casted to nullptr_t - explicit Optional(int val) = delete; - Optional& operator=(int val) = delete; - /*! - * \return A not-null container value in the optional. - * \note This function performs not-null checking. - */ - T value() const { - ICHECK(data_ != nullptr); - return T(data_); - } - /*! - * \return The contained value if the Optional is not null - * otherwise return the default_value. - */ - T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } - - /*! \return Whether the container is not nullptr.*/ - explicit operator bool() const { return *this != nullptr; } - // operator overloadings - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - auto operator==(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() == other.value()); - if (same_as(other)) return RetType(true); - if (*this != nullptr && other != nullptr) { - return value() == other.value(); - } else { - // one of them is nullptr. - return RetType(false); - } - } - auto operator!=(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() != other.value()); - if (same_as(other)) return RetType(false); - if (*this != nullptr && other != nullptr) { - return value() != other.value(); - } else { - // one of them is nullptr. - return RetType(true); - } - } - auto operator==(const T& other) const { - using RetType = decltype(value() == other); - if (same_as(other)) return RetType(true); - if (*this != nullptr) return value() == other; - return RetType(false); - } - auto operator!=(const T& other) const { return !(*this == other); } - template - auto operator==(const U& other) const { - using RetType = decltype(value() == other); - if (*this == nullptr) return RetType(false); - return value() == other; - } - template - auto operator!=(const U& other) const { - using RetType = decltype(value() != other); - if (*this == nullptr) return RetType(true); - return value() != other; - } - static constexpr bool _type_is_nullable = true; -}; - -/*! - * \brief An object representing a closure. This object is used by both the - * Relay VM and interpreter. - */ -class ClosureObj : public Object { - public: - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure; - static constexpr const char* _type_key = "runtime.Closure"; - TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); -}; - -/*! \brief reference to closure. */ -class Closure : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); -}; - -#if (USE_FALLBACK_STL_MAP != 0) - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of the actual underlying container */ - using ContainerType = std::unordered_map; - /*! \brief Iterator class */ - using iterator = ContainerType::iterator; - /*! \brief Iterator class */ - using const_iterator = ContainerType::const_iterator; - /*! \brief Type of value stored in the hash map */ - using KVType = ContainerType::value_type; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return data_.size(); } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return data_.count(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return data_.at(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return data_.at(key); } - /*! \return begin iterator */ - iterator begin() { return data_.begin(); } - /*! \return const begin iterator */ - const_iterator begin() const { return data_.begin(); } - /*! \return end iterator */ - iterator end() { return data_.end(); } - /*! \return end iterator */ - const_iterator end() const { return data_.end(); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - const_iterator find(const key_type& key) const { return data_.find(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) { return data_.find(key); } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { data_.erase(position); } - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { data_.erase(key); } - /*! - * \brief Create an empty container - * \return The object created - */ - static ObjectPtr Empty() { return make_object(); } - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static ObjectPtr CreateFromRange(IterType first, IterType last) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(first, last); - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - MapNode* map_node = static_cast(map->get()); - map_node->data_[kv.first] = kv.second; - } - /*! - * \brief Create an empty container with elements copying from another MapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(MapNode* from) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(from->data_.begin(), from->data_.end()); - return p; - } - /*! \brief The real container storing data */ - ContainerType data_; - template - friend class Map; -}; - -#else - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; - /*! \brief Default constructor */ - iterator() : index(0), self(nullptr) {} - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { return *((*this).operator->()); } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - protected: - /*! \brief Construct by value */ - iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} - /*! \brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapNode* self; - - friend class DenseMapNode; - friend class SmallMapNode; - }; - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapNode* from); - /*! \brief number of slots minus 1 */ - uint64_t slots_; - /*! \brief number of entries in the container */ - uint64_t size_; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapNode : public MapNode, - public runtime::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapNode::iterator; - using MapNode::KVType; - - /*! \brief Defaults to the destructor of InplaceArrayBase */ - ~SmallMapNode() = default; - /*! - * \brief Count the number of times a key exists in the SmallMapNode - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(AddressOf(0)); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (ObjectEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Remove a position in SmallMapNode - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(AddressOf(0)); - KVType* last = begin + (size_ - 1); - if (index + 1 == size_) { - last->first.ObjectRef::~ObjectRef(); - last->second.ObjectRef::~ObjectRef(); - } else { - *(begin + index) = std::move(*last); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::runtime::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->size_ = 0; - p->slots_ = n; - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->AddressOf(0)); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapNode* from) { - KVType* first = static_cast(from->AddressOf(0)); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - SmallMapNode* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->slots_) { - KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); - new (ptr) KVType(kv); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - ICHECK_GT(next_size, map_node->slots_); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapNode; - friend class DenseMapNode; - friend class runtime::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. Overview - * - * DenseMapNode did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapNode : public MapNode { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout::value, "Block is not standard layout"); - - public: - using MapNode::iterator; - - /*! - * \brief Destroy the DenseMapNode - */ - ~DenseMapNode() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->slots_) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { - if (slots_ == 0) { - return iterator(0, this); - } - for (uint64_t index = 0; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return iterator(index, this); - } - } - return iterator(slots_ + 1, this); - } - /*! \return end iterator */ - iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } - - private: - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (ObjectEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(ObjectHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (ObjectEqual()(key, next.Key())) { - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(KVType(key, ObjectRef(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - empty.NewTail(std::move(r.Data())); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - iter.Data().KVType::~KVType(); - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - iter.Data() = std::move(last.Data()); - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! \brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->slots_); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->KVType::~KVType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - delete[] data_; - data_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots - 1); - Block* block = p->data_ = new Block[n_blocks]; - p->slots_ = n_slots - 1; - p->size_ = 0; - p->fib_shift_ = fib_shift; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapNode* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->slots_); - p->data_ = new Block[n_blocks]; - p->slots_ = from->slots_; - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->data_[bi].bytes; - KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); - uint8_t* meta_ptr_to = p->data_[bi].bytes; - KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) KVType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - DenseMapNode* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = kv.second; - return; - } - ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); - // Insert the given `kv` into the new hash map - InsertMaybeReHash(kv, &p); - uint64_t n_blocks = CalcNumBlocks(map_node->slots_); - // Then Insert data from the original block. - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = map_node->data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - KVType kv = std::move(*data_ptr); - InsertMaybeReHash(kv, &p); - } - } - } - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - for (++index; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - while (index != 0) { - index -= 1; - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { - uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; - return (n_slots + kBlockCap - 1) / kBlockCap; - } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapNode* self) - : index(index), block(self->data_ + (index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - KVType& Data() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(KVType))); - } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(KVType v) const { - Meta() = 0b00000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(KVType v) const { - Meta() = 0b10000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self, uint8_t meta) { - uint64_t offset = kNextProbeLocation[meta & 0b01111111]; - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - index = (index + offset) & (self->slots_); - block = self->data_ + (index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapNode* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(ObjectHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief array of data blocks */ - Block* data_; - /* clang-format off */ - /*! \brief Candidates of probing distance */ - TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, - 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - friend class MapNode; -}; - -#define TVM_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapNode*; \ - using TDense = DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapNode*; \ - using TDense = const DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapNode::iterator::pointer MapNode::iterator::operator->() const { - TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapNode::iterator& MapNode::iterator::operator++() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapNode::iterator& MapNode::iterator::operator--() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapNode::count(const key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { - TVM_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapNode::iterator MapNode::begin() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapNode::iterator MapNode::end() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapNode::erase(const MapNode::iterator& position) { - TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); -} - -#undef TVM_DISPATCH_MAP -#undef TVM_DISPATCH_MAP_CONST - -inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } - -inline ObjectPtr MapNode::CopyFrom(MapNode* from) { - if (from->slots_ <= SmallMapNode::kMaxSize) { - return SmallMapNode::CopyFrom(static_cast(from)); - } else { - return DenseMapNode::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapNode::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapNode::kMaxSize) { - return SmallMapNode::CreateFromRange(cap, first, last); - } - uint32_t fib_shift; - uint64_t n_slots; - DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapNode::InsertMaybeReHash(kv, &obj); - } - return obj; -} - -inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; - MapNode* base = static_cast(map->get()); - if (base->slots_ < kSmallMapMaxSize) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else if (base->slots_ == kSmallMapMaxSize) { - if (base->size_ < base->slots_) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else { - ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); - DenseMapNode::InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - } else { - DenseMapNode::InsertMaybeReHash(kv, map); - } -} - -template <> -inline ObjectPtr make_object<>() = delete; - -#endif - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -class Map : public ObjectRef { - public: - using key_type = K; - using mapped_type = V; - class iterator; - /*! - * \brief default constructor - */ - Map() { data_ = MapNode::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) { data_ = std::move(other.data_); } - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapNode::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : GetMapNode()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapNode* n = GetMapNode(); - if (n != nullptr) { - data_ = MapNode::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapNode()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapNode()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - MapNode* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapNode::Empty(); - } else if (!data_.unique()) { - data_ = MapNode::CopyFrom(GetMapNode()); - } - return GetMapNode(); - } - /*! \brief specify container node */ - using ContainerType = MapNode; - - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - private: - iterator(const MapNode::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapNode::iterator itr; - }; - - private: - /*! \brief Return data_ as type of pointer of MapNode */ - MapNode* GetMapNode() const { return static_cast(data_.get()); } -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -} // namespace runtime - -// expose the functions to the root namespace. -using runtime::Array; -using runtime::ArrayNode; -using runtime::Downcast; -using runtime::IterAdapter; -using runtime::make_object; -using runtime::Map; -using runtime::MapNode; -using runtime::Object; -using runtime::ObjectEqual; -using runtime::ObjectHash; -using runtime::ObjectPtr; -using runtime::ObjectPtrEqual; -using runtime::ObjectPtrHash; -using runtime::ObjectRef; -using runtime::Optional; -using runtime::String; -using runtime::StringObj; -constexpr runtime::NullOptType NullOpt{}; -} // namespace tvm - -namespace std { - -template <> -struct hash<::tvm::runtime::String> { - std::size_t operator()(const ::tvm::runtime::String& str) const { - return ::tvm::runtime::String::HashBytes(str.data(), str.size()); - } -}; -} // namespace std - -#endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/include/tvm/runtime/container/adt.h b/include/tvm/runtime/container/adt.h new file mode 100644 index 000000000000..20c4f796d741 --- /dev/null +++ b/include/tvm/runtime/container/adt.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/adt.h + * \brief Runtime ADT container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_ADT_H_ +#define TVM_RUNTIME_CONTAINER_ADT_H_ + +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief An object representing a structure or enumeration. */ +class ADTObj : public Object, public InplaceArrayBase { + public: + /*! \brief The tag representing the constructor used. */ + int32_t tag; + /*! \brief Number of fields in the ADT object. */ + uint32_t size; + // The fields of the structure follows directly in memory. + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; + static constexpr const char* _type_key = "runtime.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); + + private: + /*! + * \return The number of elements in the array. + */ + size_t GetSize() const { return size; } + + /*! + * \brief Initialize the elements in the array. + * + * \tparam Iterator Iterator type of the array. + * \param begin The begin iterator. + * \param end The end iterator. + */ + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + // Only increment size after the initialization succeeds + this->size++; + } + } + + friend class ADT; + friend InplaceArrayBase; +}; + +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { + public: + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param fields The fields of the ADT object. + * \return The constructed ADT object reference. + */ + ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param begin The begin iterator to the start of the fields array. + * \param end The end iterator to the end of the fields array. + * \return The constructed ADT object reference. + */ + template + ADT(int32_t tag, Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + auto ptr = make_inplace_array_object(num_elems); + ptr->tag = tag; + ptr->Init(begin, end); + data_ = std::move(ptr); + } + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param init The initializer list of fields. + * \return The constructed ADT object reference. + */ + ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; + + /*! + * \brief Access element at index. + * + * \param idx The array index + * \return const ObjectRef + */ + const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } + + /*! + * \brief Return the ADT tag. + */ + int32_t tag() const { return operator->()->tag; } + + /*! + * \brief Return the number of fields. + */ + size_t size() const { return operator->()->size; } + + /*! + * \brief Construct a tuple object. + * + * \tparam Args Type params of tuple feilds. + * \param args Tuple fields. + * \return ADT The tuple object reference. + */ + template + static ADT Tuple(Args&&... args) { + return ADT(0, std::forward(args)...); + } + + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTAINER_ADT_H_ diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h new file mode 100644 index 000000000000..8830653da88c --- /dev/null +++ b/include/tvm/runtime/container/array.h @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/array.h + * \brief Runtime Array container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ +#define TVM_RUNTIME_CONTAINER_ARRAY_H_ + +#include +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const ObjectRef at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const ObjectRef* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + ICHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + ICHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { + ObjectPtr p = ArrayNode::Empty(n); + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) ObjectRef(val); + } + return p; + } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + ObjectRef* MutableBegin() const { + return static_cast(InplaceArrayBase::AddressOf(0)); + } + + /*! \return end mutable iterator */ + ObjectRef* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + ICHECK_GE(n, 0); + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { + ObjectRef* itr = MutableBegin() + idx; + for (; first != last; ++first) { + ObjectRef ref = *first; + new (itr++) ObjectRef(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_begin; + ObjectRef* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_end; + ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) ObjectRef(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* ShrinkBy(int64_t delta) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->ObjectRef::~ObjectRef(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! + * \brief Array, container representing a contigious sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam T The content ObjectRef type. + */ +template ::value>::type> +class Array : public ObjectRef { + public: + using value_type = T; + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } + }; + + using iterator = IterAdapter; + using reverse_iterator = ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayNode::end() is never nullptr + return reverse_iterator(GetArrayNode()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayNode::begin() is never nullptr + return reverse_iterator(GetArrayNode()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK(0 <= i && i < p->size_) + << "IndexError: indexing " << i << " on an array of size " << p->size_; + return DowncastNoCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) ObjectRef(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + if (first == last) { + return; + } + ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; + int64_t size = GetArrayNode()->size_; + ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st + << ", because Array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; + ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) + << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayNode* p = this->CopyOnWrite(); + ICHECK(0 <= i && i < p->size_) + << "IndexError: indexing " << i << " on an array of size " << p->size_; + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + void MutateByApply(F fmutate) { + if (data_ == nullptr) { + return; + } + struct StackFrame { + ArrayNode* p; + ObjectRef* itr; + int64_t i; + int64_t size; + }; + std::unique_ptr s = std::make_unique(); + s->p = GetArrayNode(); + s->itr = s->p->MutableBegin(); + s->i = 0; + s->size = s->p->size_; + if (!data_.unique()) { + // Loop invariant: keeps iterating when + // 1) data is not unique + // 2) no elements are actually mutated yet + for (; s->i < s->size; ++s->i, ++s->itr) { + T new_elem = fmutate(DowncastNoCheck(*s->itr)); + // do nothing when there is no mutation + if (new_elem.same_as(*s->itr)) { + continue; + } + // loop invariant breaks when the first real mutation happens + // we copy the elements into a new unique array + ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); + s->itr = copy->MutableBegin() + (s->i++); + *s->itr++ = std::move(new_elem); + data_ = std::move(copy); + // make sure `data_` is unique and break + break; + } + } + // when execution comes to this line, it is guaranteed that either + // 1) i == size + // or 2) data_.unique() is true + for (; s->i < s->size; ++s->i, ++s->itr) { + *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); + } + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; + ArrayNode* p = GetArrayNode(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayNode::Empty(cap); + p = GetArrayNode(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) ObjectRef(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayNode* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayNode::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayNode pointer to the unique copy + */ + ArrayNode* CopyOnWrite(int64_t reserve_extra) { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayNode::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayNode::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayNode to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayNode* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayNode::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); + } else { + data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); + } + return static_cast(data_.get()); + } +}; + +/*! + * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template ::value>::type> +inline Array Concat(Array lhs, const Array& rhs) { + for (const auto& x : rhs) { + lhs.push_back(x); + } + return std::move(lhs); +} + +// Specialize make_object to make sure it is correct. +template <> +inline ObjectPtr make_object() { + return ArrayNode::Empty(); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Array; +using runtime::ArrayNode; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ diff --git a/include/tvm/runtime/container/base.h b/include/tvm/runtime/container/base.h new file mode 100644 index 000000000000..4112c213d6f0 --- /dev/null +++ b/include/tvm/runtime/container/base.h @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/base.h + * \brief Base utilities for common POD(plain old data) container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_BASE_H_ +#define TVM_RUNTIME_CONTAINER_BASE_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief String-aware ObjectRef equal functor */ +struct ObjectHash { + /*! + * \brief Calculate the hash code of an ObjectRef + * \param a The given ObjectRef + * \return Hash code of a, string hash for strings and pointer address otherwise. + */ + size_t operator()(const ObjectRef& a) const; +}; + +/*! \brief String-aware ObjectRef hash functor */ +struct ObjectEqual { + /*! + * \brief Check if the two ObjectRef are equal + * \param a One ObjectRef + * \param b The other ObjectRef + * \return String equality if both are strings, pointer address equality otherwise. + */ + bool operator()(const ObjectRef& a, const ObjectRef& b) const; +}; + +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. + * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { + public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if (!(std::is_standard_layout::value && std::is_trivial::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter& operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter& operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Downcast; +using runtime::IterAdapter; +using runtime::make_object; +using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; +using runtime::ObjectPtr; +using runtime::ObjectPtrEqual; +using runtime::ObjectPtrHash; +using runtime::ObjectRef; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BASE_H_ diff --git a/include/tvm/runtime/container/closure.h b/include/tvm/runtime/container/closure.h new file mode 100644 index 000000000000..a280d1ada7a9 --- /dev/null +++ b/include/tvm/runtime/container/closure.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/closure.h + * \brief Runtime Closure container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_CLOSURE_H_ +#define TVM_RUNTIME_CONTAINER_CLOSURE_H_ + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief An object representing a closure. This object is used by both the + * Relay VM and interpreter. + */ +class ClosureObj : public Object { + public: + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure; + static constexpr const char* _type_key = "runtime.Closure"; + TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. */ +class Closure : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_CLOSURE_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h new file mode 100644 index 000000000000..671e38b83581 --- /dev/null +++ b/include/tvm/runtime/container/map.h @@ -0,0 +1,1441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/map.h + * \brief Runtime Map container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_MAP_H_ +#define TVM_RUNTIME_CONTAINER_MAP_H_ + +#ifndef USE_FALLBACK_STL_MAP +#define USE_FALLBACK_STL_MAP 0 +#endif + +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +#if (USE_FALLBACK_STL_MAP != 0) + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of the actual underlying container */ + using ContainerType = std::unordered_map; + /*! \brief Iterator class */ + using iterator = ContainerType::iterator; + /*! \brief Iterator class */ + using const_iterator = ContainerType::const_iterator; + /*! \brief Type of value stored in the hash map */ + using KVType = ContainerType::value_type; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return data_.size(); } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return data_.count(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return data_.at(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return data_.at(key); } + /*! \return begin iterator */ + iterator begin() { return data_.begin(); } + /*! \return const begin iterator */ + const_iterator begin() const { return data_.begin(); } + /*! \return end iterator */ + iterator end() { return data_.end(); } + /*! \return end iterator */ + const_iterator end() const { return data_.end(); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + const_iterator find(const key_type& key) const { return data_.find(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) { return data_.find(key); } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { data_.erase(position); } + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { data_.erase(key); } + /*! + * \brief Create an empty container + * \return The object created + */ + static ObjectPtr Empty() { return make_object(); } + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static ObjectPtr CreateFromRange(IterType first, IterType last) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(first, last); + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + MapNode* map_node = static_cast(map->get()); + map_node->data_[kv.first] = kv.second; + } + /*! + * \brief Create an empty container with elements copying from another MapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(MapNode* from) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(from->data_.begin(), from->data_.end()); + return p; + } + /*! \brief The real container storing data */ + ContainerType data_; + template + friend class Map; +}; + +#else + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; + /*! \brief Default constructor */ + iterator() : index(0), self(nullptr) {} + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { return *((*this).operator->()); } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + protected: + /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapNode* self; + + friend class DenseMapNode; + friend class SmallMapNode; + }; + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapNode* from); + /*! \brief number of slots minus 1 */ + uint64_t slots_; + /*! \brief number of entries in the container */ + uint64_t size_; + // Reference class + template + friend class Map; +}; + +/*! \brief A specialization of small-sized hash map */ +class SmallMapNode : public MapNode, + public runtime::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapNode::iterator; + using MapNode::KVType; + + /*! \brief Defaults to the destructor of InplaceArrayBase */ + ~SmallMapNode() = default; + /*! + * \brief Count the number of times a key exists in the SmallMapNode + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(AddressOf(0)); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (ObjectEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Remove a position in SmallMapNode + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(AddressOf(0)); + KVType* last = begin + (size_ - 1); + if (index + 1 == size_) { + last->first.ObjectRef::~ObjectRef(); + last->second.ObjectRef::~ObjectRef(); + } else { + *(begin + index) = std::move(*last); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::runtime::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->size_ = 0; + p->slots_ = n; + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->AddressOf(0)); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapNode* from) { + KVType* first = static_cast(from->AddressOf(0)); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + SmallMapNode* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->slots_) { + KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); + new (ptr) KVType(kv); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + next_size = std::min(next_size, uint64_t(kMaxSize)); + ICHECK_GT(next_size, map_node->slots_); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapNode; + friend class DenseMapNode; + friend class runtime::InplaceArrayBase; +}; + +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. Overview + * + * DenseMapNode did several improvements over traditional separate chaining hash, + * in terms of cache locality, memory footprints and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using linked list + * explicitly for each bucket, we store list data into a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and + * traversal. This can be divided in 3 parts. + * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is + * head of a linked list. + * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to + * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. + * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid + * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2. Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapNode : public MapNode { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout::value, "Block is not standard layout"); + + public: + using MapNode::iterator; + + /*! + * \brief Destroy the DenseMapNode + */ + ~DenseMapNode() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->slots_) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { + if (slots_ == 0) { + return iterator(0, this); + } + for (uint64_t index = 0; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return iterator(index, this); + } + } + return iterator(slots_ + 1, this); + } + /*! \return end iterator */ + iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } + + private: + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (ObjectEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; + return iter.Val(); + } + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(ObjectHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (ObjectEqual()(key, next.Key())) { + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(KVType(key, ObjectRef(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list. + * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + empty.NewTail(std::move(r.Data())); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done moving the linked list + // fill data_ into `target` + target.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + iter.Data().KVType::~KVType(); + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + iter.Data() = std::move(last.Data()); + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! \brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->slots_); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + data_ptr->KVType::~KVType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + delete[] data_; + data_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots - 1); + Block* block = p->data_ = new Block[n_blocks]; + p->slots_ = n_slots - 1; + p->size_ = 0; + p->fib_shift_ = fib_shift; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another DenseMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapNode* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->slots_); + p->data_ = new Block[n_blocks]; + p->slots_ = from->slots_; + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->data_[bi].bytes; + KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); + uint8_t* meta_ptr_to = p->data_[bi].bytes; + KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + ICHECK(meta != kProtectedSlot); + if (meta != uint8_t(kEmptySlot)) { + new (data_ptr_to) KVType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + DenseMapNode* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = kv.second; + return; + } + ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); + // Insert the given `kv` into the new hash map + InsertMaybeReHash(kv, &p); + uint64_t n_blocks = CalcNumBlocks(map_node->slots_); + // Then Insert data from the original block. + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = map_node->data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + KVType kv = std::move(*data_ptr); + InsertMaybeReHash(kv, &p); + } + } + } + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + for (++index; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + while (index != 0) { + index -= 1; + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { + uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; + return (n_slots + kBlockCap - 1) / kBlockCap; + } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. + * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapNode* self) + : index(index), block(self->data_ + (index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + KVType& Data() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(KVType))); + } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } + /*! \brief Set the entry to be protected */ + void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(KVType v) const { + Meta() = 0b00000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(KVType v) const { + Meta() = 0b10000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self, uint8_t meta) { + uint64_t offset = kNextProbeLocation[meta & 0b01111111]; + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + index = (index + offset) & (self->slots_); + block = self->data_ + (index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapNode* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(ObjectHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief array of data blocks */ + Block* data_; + /* clang-format off */ + /*! \brief Candidates of probing distance */ + TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, + 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + friend class MapNode; +}; + +#define TVM_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapNode*; \ + using TDense = DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +#define TVM_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapNode*; \ + using TDense = const DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapNode::iterator& MapNode::iterator::operator++() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapNode::iterator& MapNode::iterator::operator--() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); +} + +inline size_t MapNode::count(const key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { + TVM_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapNode::iterator MapNode::begin() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapNode::iterator MapNode::end() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapNode::erase(const MapNode::iterator& position) { + TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); +} + +#undef TVM_DISPATCH_MAP +#undef TVM_DISPATCH_MAP_CONST + +inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } + +inline ObjectPtr MapNode::CopyFrom(MapNode* from) { + if (from->slots_ <= SmallMapNode::kMaxSize) { + return SmallMapNode::CopyFrom(static_cast(from)); + } else { + return DenseMapNode::CopyFrom(static_cast(from)); + } +} + +template +inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapNode::Empty(); + } + uint64_t cap = static_cast(_cap); + if (cap < SmallMapNode::kMaxSize) { + return SmallMapNode::CreateFromRange(cap, first, last); + } + uint32_t fib_shift; + uint64_t n_slots; + DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapNode::InsertMaybeReHash(kv, &obj); + } + return obj; +} + +inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; + MapNode* base = static_cast(map->get()); + if (base->slots_ < kSmallMapMaxSize) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else if (base->slots_ == kSmallMapMaxSize) { + if (base->size_ < base->slots_) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else { + ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); + DenseMapNode::InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + } else { + DenseMapNode::InsertMaybeReHash(kv, map); + } +} + +template <> +inline ObjectPtr make_object<>() = delete; + +#endif + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template ::value>::type, + typename = typename std::enable_if::value>::type> +class Map : public ObjectRef { + public: + using key_type = K; + using mapped_type = V; + class iterator; + /*! + * \brief default constructor + */ + Map() { data_ = MapNode::Empty(); } + /*! + * \brief move constructor + * \param other source + */ + Map(Map&& other) { data_ = std::move(other.data_); } + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map& other) : ObjectRef(other.data_) {} + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + data_ = MapNode::CreateFromRange(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list> init) { + data_ = MapNode::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief constructor from unordered_map + * \param init The unordered_map + */ + template + Map(const std::unordered_map& init) { // NOLINT(*) + data_ = MapNode::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V operator[](const K& key) const { return this->at(key); } + /*! \return The size of the array */ + size_t size() const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : n->size(); + } + /*! \return The number of elements of the key */ + size_t count(const K& key) const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : GetMapNode()->count(key); + } + /*! \return whether array is empty */ + bool empty() const { return size() == 0; } + /*! \brief Release reference to all the elements */ + void clear() { + MapNode* n = GetMapNode(); + if (n != nullptr) { + data_ = MapNode::Empty(); + } + } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + void Set(const K& key, const V& value) { + CopyOnWrite(); + MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapNode()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapNode()->end()); } + /*! \return find the key and returns the associated iterator */ + iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } + + void erase(const K& key) { CopyOnWrite()->erase(key); } + + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + MapNode* CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapNode::Empty(); + } else if (!data_.unique()) { + data_ = MapNode::CopyFrom(GetMapNode()); + } + return GetMapNode(); + } + /*! \brief specify container node */ + using ContainerType = MapNode; + + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + iterator() : itr() {} + + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto& kv = *itr; + return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + + private: + iterator(const MapNode::iterator& itr) // NOLINT(*) + : itr(itr) {} + + template + friend class Map; + + MapNode::iterator itr; + }; + + private: + /*! \brief Return data_ as type of pointer of MapNode */ + MapNode* GetMapNode() const { return static_cast(data_.get()); } +}; + +/*! + * \brief Merge two Maps. + * \param lhs the first Map to merge. + * \param rhs the second Map to merge. + * @return The merged Array. Original Maps are kept unchanged. + */ +template ::value>::type, + typename = typename std::enable_if::value>::type> +inline Map Merge(Map lhs, const Map& rhs) { + for (const auto& p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Map; +using runtime::MapNode; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_MAP_H_ diff --git a/include/tvm/runtime/container/optional.h b/include/tvm/runtime/container/optional.h new file mode 100644 index 000000000000..bea4228c48b8 --- /dev/null +++ b/include/tvm/runtime/container/optional.h @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/optional.h + * \brief Runtime Optional container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_OPTIONAL_H_ +#define TVM_RUNTIME_CONTAINER_OPTIONAL_H_ + +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief Helper to represent nullptr for optional. */ +struct NullOptType {}; + +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * ICHECK(opt0 == nullptr); + * ICHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + Optional(NullOptType) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + ICHECK(data_ != nullptr); + return T(data_); + } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } + + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { return *this != nullptr; } + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { return !(*this == other); } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; + } + static constexpr bool _type_is_nullable = true; +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Optional; +constexpr runtime::NullOptType NullOpt{}; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_OPTIONAL_H_ diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h new file mode 100644 index 000000000000..774077fc3d5e --- /dev/null +++ b/include/tvm/runtime/container/shape_tuple.h @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/shape_tuple.h + * \brief Runtime ShapeTuple container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ +#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ + +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief An object representing a shape tuple. */ +class ShapeTupleObj : public Object { + public: + /*! \brief The type of shape index element. */ + using index_type = int64_t; + /*! \brief The pointer to shape tuple data. */ + index_type* data; + /*! \brief The size of the shape tuple object. */ + uint64_t size; + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple; + static constexpr const char* _type_key = "runtime.ShapeTuple"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object); + + private: + /*! \brief ShapeTuple object which is moved from std::vector container. */ + class FromStd; + + friend class ShapeTuple; +}; + +/*! \brief An object representing shape tuple moved from std::vector. */ +class ShapeTupleObj::FromStd : public ShapeTupleObj { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeTupleObj::index_type; + /*! + * \brief Construct a new FromStd object + * + * \param other The moved/copied std::vector object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit FromStd(std::vector other) : data_container{other} {} + + private: + /*! \brief Container that holds the memory. */ + std::vector data_container; + + friend class ShapeTuple; +}; + +/*! + * \brief Reference to shape tuple objects. + */ +class ShapeTuple : public ObjectRef { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeTupleObj::index_type; + + /*! + * \brief Construct an empty shape tuple. + */ + ShapeTuple() : ShapeTuple(std::vector()) {} + + /*! + * \brief Constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector(begin, end)) {} + + /*! + * \brief constructor from initializer list + * \param shape The initializer list + */ + ShapeTuple(std::initializer_list shape) : ShapeTuple(shape.begin(), shape.end()) {} + + /*! + * \brief Construct a new ShapeTuple object + * + * \param shape The moved/copied std::vector object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + ShapeTuple(std::vector shape); // NOLINT(*) + + /*! + * \brief Return the data pointer + * + * \return const index_type* data pointer + */ + const index_type* data() const { return get()->data; } + + /*! + * \brief Return the size of the shape tuple + * + * \return size_t shape tuple size + */ + size_t size() const { return get()->size; } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + index_type operator[](size_t idx) const { + ICHECK(0 <= idx && idx < this->size()) + << "IndexError: indexing " << idx << " on an array of size " << this->size(); + return this->data()[idx]; + } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + index_type at(size_t idx) const { return this->operator[](idx); } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the shape tuple */ + index_type front() const { return this->at(0); } + + /*! \return The last element of the shape tuple */ + index_type back() const { return this->at(this->size() - 1); } + + /*! \return begin iterator */ + const index_type* begin() const { return get()->data; } + + /*! \return end iterator */ + const index_type* end() const { return (get()->data + size()); } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj); +}; + +inline ShapeTuple::ShapeTuple(std::vector shape) { + auto ptr = make_object(std::move(shape)); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::ShapeTuple; +using runtime::ShapeTupleObj; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h new file mode 100644 index 000000000000..664d19818be1 --- /dev/null +++ b/include/tvm/runtime/container/string.h @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/string.h + * \brief Runtime String container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_STRING_H_ +#define TVM_RUNTIME_CONTAINER_STRING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +// We use c++14 std::experimental::string_view for optimizing hash computation +// only right now, its usage is limited in this file. Any broader usage of +// std::experiment in our core codebase is discouraged and needs community +// discussion for each use case. Reference for feature test macros of +// string_view: +// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations +// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros +#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 +#define TVM_USE_CXX14_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX14_STRING_VIEW_HASH 0 +#endif + +// Tested with clang version 9.0.1 and c++17. It will detect string_view support +// correctly. +#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606 +#define TVM_USE_CXX17_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX17_STRING_VIEW_HASH 0 +#endif + +#if TVM_USE_CXX17_STRING_VIEW_HASH +#include +#elif TVM_USE_CXX14_STRING_VIEW_HASH +#include +#endif + +#include +#include +#include + +namespace llvm { +// String to llvm object compatibility. +class StringRef; +} // namespace llvm + +namespace tvm { +namespace runtime { + +// Forward declare TVMArgValue +class TVMArgValue; + +/*! \brief An object representing string. It's POD type. */ +class StringObj : public Object { + public: + /*! \brief The pointer to string data. */ + const char* data; + + /*! \brief The length of the string object. */ + uint64_t size; + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; + static constexpr const char* _type_key = "runtime.String"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); + + private: + /*! \brief String object which is moved from std::string container. */ + class FromStd; + + friend class String; +}; + +/*! + * \brief Reference to string objects. + * + * \code + * + * // Example to create runtime String reference object from std::string + * std::string s = "hello world"; + * + * // You can create the reference from existing std::string + * String ref{std::move(s)}; + * + * // You can rebind the reference to another string. + * ref = std::string{"hello world2"}; + * + * // You can use the reference as hash map key + * std::unordered_map m; + * m[ref] = 1; + * + * // You can compare the reference object with other string objects + * assert(ref == "hello world", true); + * + * // You can convert the reference to std::string again + * string s2 = (string)ref; + * + * \endcode + */ +class String : public ObjectRef { + public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} + /*! + * \brief Construct a new String object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + String(std::string other); // NOLINT(*) + + /*! + * \brief Construct a new String object + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : String(std::string(other)) {} + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + * + */ + inline String& operator=(std::string other); + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + */ + inline String& operator=(const char* other); + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char* other) const { + return memncmp(data(), other, size(), std::strlen(other)); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { + const auto* ptr = get(); + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Retun if the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h + /*! + * \brief Convert String to an llvm::StringRef object + * + * \return llvm::StringRef + */ + inline operator llvm::StringRef() const; + + /*! + * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String + * \param val The value to be checked + * \return A boolean indicating if val can be converted to String + */ + inline static bool CanConvertFrom(const TVMArgValue& val); + + /*! + * \brief Hash the binary bytes + * \param data The data pointer + * \param size The size of the bytes. + * \return the hash value. + */ + static size_t HashBytes(const char* data, size_t size) { + // This function falls back to string copy with c++11 compiler and is + // recommended to be compiled with c++14 +#if TVM_USE_CXX17_STRING_VIEW_HASH + return std::hash()(std::string_view(data, size)); +#elif TVM_USE_CXX14_STRING_VIEW_HASH + return std::hash()(std::experimental::string_view(data, size)); +#else + return std::hash()(std::string(data, size)); +#endif + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + + private: + /*! + * \brief Compare two char sequence + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appear before other, positive otherwise. + */ + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { + std::string ret(lhs, lhs_size); + ret.append(rhs, rhs_size); + return String(ret); + } + + // Overload + operator + friend String operator+(const String& lhs, const String& rhs); + friend String operator+(const String& lhs, const std::string& rhs); + friend String operator+(const std::string& lhs, const String& rhs); + friend String operator+(const String& lhs, const char* rhs); + friend String operator+(const char* lhs, const String& rhs); + + friend struct tvm::runtime::ObjectEqual; +}; + +/*! \brief An object representing string moved from std::string. */ +class StringObj::FromStd : public StringObj { + public: + /*! + * \brief Construct a new FromStd object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit FromStd(std::string other) : data_container{other} {} + + private: + /*! \brief Container that holds the memory. */ + std::string data_container; + + friend class String; +}; + +inline String::String(std::string other) { + auto ptr = make_object(std::move(other)); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +inline String& String::operator=(std::string other) { + String replace{std::move(other)}; + data_.swap(replace.data_); + return *this; +} + +inline String& String::operator=(const char* other) { return operator=(std::string(other)); } + +inline String operator+(const String& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const std::string& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const std::string& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char* lhs, const String& rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const char* rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// Overload < operator +inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } + +// Overload == operator +inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) return -1; + if (lhs[i] > rhs[i]) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } +} + +inline size_t ObjectHash::operator()(const ObjectRef& a) const { + if (const auto* str = a.as()) { + return String::HashBytes(str->data, str->size); + } + return ObjectPtrHash()(a); +} + +inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { + if (a.same_as(b)) { + return true; + } + if (const auto* str_a = a.as()) { + if (const auto* str_b = b.as()) { + return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; + } + } + return false; +} +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::StringObj; +} // namespace tvm + +namespace std { + +template <> +struct hash<::tvm::runtime::String> { + std::size_t operator()(const ::tvm::runtime::String& str) const { + return ::tvm::runtime::String::HashBytes(str.data(), str.size()); + } +}; +} // namespace std + +#endif // TVM_RUNTIME_CONTAINER_STRING_H_ diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index a493469a333d..58b9ff1932cc 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -257,8 +257,6 @@ inline const char* DeviceName(int type) { return "ext_dev"; case kDLWebGPU: return "webgpu"; - case kDLMicroDev: - return "micro_dev"; case kDLHexagon: return "hexagon"; default: diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index ada9b74503bc..1127a9ae732c 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,7 +25,9 @@ #define TVM_RUNTIME_NDARRAY_H_ #include -#include +#include +#include +#include #include #include #include @@ -126,7 +128,7 @@ class NDArray : public ObjectRef { * \param dtype The data type of the new array. * \note The memory size of new array must be smaller than the current one. */ - TVM_DLL NDArray CreateView(std::vector shape, DLDataType dtype); + TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype); /*! * \brief Create a reference view of NDArray that * represents as DLManagedTensor. @@ -141,7 +143,7 @@ class NDArray : public ObjectRef { * \param mem_scope The memory scope of the array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, Device dev, + TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope = NullOpt); /*! * \brief Create a NDArray backed by a dlpack tensor. @@ -164,7 +166,7 @@ class NDArray : public ObjectRef { TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); - TVM_DLL std::vector Shape() const; + TVM_DLL ShapeTuple Shape() const; TVM_DLL runtime::DataType DataType() const; // internal namespace struct Internal; @@ -239,7 +241,7 @@ class NDArray::ContainerBase { * \brief The shape container, * can be used used for shape data. */ - std::vector shape_; + ShapeTuple shape_; }; /*! @@ -259,13 +261,13 @@ class NDArray::Container : public Object, public NDArray::ContainerBase { dl_tensor.byte_offset = 0; } - Container(void* data, std::vector shape, DLDataType dtype, Device dev) { + Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) { // Initialize the type index. type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; shape_ = std::move(shape); dl_tensor.ndim = static_cast(shape_.size()); - dl_tensor.shape = dmlc::BeginPtr(shape_); + dl_tensor.shape = const_cast(shape_.data()); dl_tensor.dtype = dtype; dl_tensor.strides = nullptr; dl_tensor.byte_offset = 0; @@ -355,8 +357,7 @@ inline void NDArray::CopyTo(const NDArray& other) const { inline NDArray NDArray::CopyTo(const Device& dev) const { ICHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - NDArray ret = - Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev); + NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev); this->CopyTo(ret); return ret; } @@ -458,7 +459,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) { if (ndim != 0) { ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } - NDArray ret = NDArray::Empty(shape, dtype, dev); + NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev); int64_t num_elems = 1; int elem_bytes = (ret->dtype.bits + 7) / 8; for (int i = 0; i < ret->ndim; ++i) { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index f13bdee09f87..0ed61177e65a 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -68,6 +68,8 @@ struct TypeIndex { kRuntimeArray = 4, /*! \brief runtime::Map. */ kRuntimeMap = 5, + /*! \brief runtime::ShapeTuple. */ + kRuntimeShapeTuple = 6, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 58bd2859c10a..9bfe379a3d77 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -25,7 +25,8 @@ #define TVM_RUNTIME_PACKED_FUNC_H_ #include -#include +#include +#include #include #include #include diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index e0fabfc5d8aa..2cdd180730ec 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -24,7 +24,8 @@ #ifndef TVM_RUNTIME_VM_EXECUTABLE_H_ #define TVM_RUNTIME_VM_EXECUTABLE_H_ -#include +#include +#include #include #include #include diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 15de1df98a78..58c6ee037fb5 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_VM_VM_H_ #define TVM_RUNTIME_VM_VM_H_ -#include +#include #include #include #include diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 32e74f6ef9d5..0ba7421ce409 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); -/*! - * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. - * - * \param stmt The stmt to be trasnformed. - * \param schedule The original schedule. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \return Transformed stmt. - */ -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer); - /*! * \brief Postprocessing the Stmt generated by ScheduleOps to create * a PrimFunc that can then be used for further TIR optimizations. diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 401ba102c2f4..85677a726574 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -25,7 +25,6 @@ #define TVM_TE_TENSOR_H_ #include -#include #include #include diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 83f228da9475..a01d69b372d2 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -25,7 +25,8 @@ #define TVM_TIR_BUFFER_H_ #include -#include +#include +#include #include #include diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index e1d097474dd9..40d66a2d8357 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -29,7 +29,9 @@ #include #include #include -#include +#include +#include +#include #include #include #include diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 963458ccee4a..6b5d6c48ddd0 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -29,7 +29,7 @@ #define TVM_TIR_OP_ATTR_TYPES_H_ #include -#include +#include #include namespace tvm { diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index b85fdec8cba9..9a09d0ad211f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -24,6 +24,16 @@ namespace tvm { namespace tir { +/*! \brief The level of detailed error message rendering */ +enum class ScheduleErrorRenderLevel : int32_t { + /*! \brief Render a detailed error message */ + kDetail = 0, + /*! \brief Render the error in fast mode */ + kFast = 1, + /*! \brief No error message at all */ + kNone = 2, +}; + /**************** Random variable: BlockRV ****************/ /*! \brief A random variable that evaluates to a TensorIR block */ @@ -185,6 +195,35 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; + /******** Schedule: loops manipulation ********/ + /******** Schedule: compute location ********/ + /*! + * \brief Inline a block into its consumer(s). It requires: + * 1) The block is a complete non-root block, which only produces one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The body of the block must be a BufferStore statement in the form of, + * A[i, j, k, ...] = ... + * where the indices of the LHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param block The block to be inlined to its consumer(s) + */ + virtual void ComputeInline(const BlockRV& block) = 0; + /*! + * \brief Inline a block into its only producer. It requires: + * 1) The block is a complete non-root block, which only produces and consumers one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The only producer of the block is a read-after-write producer and a complete non-root block + * 4) The body of the block must be a BufferStore statement in the form of, + * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param block The block to be inlined to its producer + */ + virtual void ReverseComputeInline(const BlockRV& block) = 0; + /******** Schedule: loop binding/annotation ********/ + /******** Schedule: cache read/write ********/ + /******** Schedule: reduction ********/ + /******** Schedule: blockize & tensorize ********/ }; /*! @@ -209,13 +248,15 @@ class Schedule : public runtime::ObjectRef { * \param mod The IRModule to be scheduled * \param debug_mode Do extra correctness checking after the class creation * and each time after calling the Replace method. + * \param error_render_level The level of error rendering * \return The concrete schedule created * \sa ScheduleDebugMask * \note The checks performed includes: * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); + TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 8273f9912a57..a6681f0b9941 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -27,7 +27,6 @@ #define TVM_TIR_STMT_FUNCTOR_H_ #include -#include #include #include #include diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 2de255da3fa2..2113d58f1ffa 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -212,6 +212,17 @@ TVM_DLL Pass InstrumentBoundCheckers(); */ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); +/*! + * \brief Transform the high-level PrimFunc to a C signature that can be used + * to call the operator directly. + * + * The main task of this function is to create code that maps the values in the + * api_args to Var that is required by body + * + * \return The pass. + */ +TVM_DLL Pass MakeUnpackedAPI(); + /*! * \brief Remap the thread axis * @@ -371,8 +382,9 @@ TVM_DLL Pass ConvertBlocksToOpaque(); /*! * \brief Compact the buffer access region by removing the buffer regions that are not accessed, * i.e. narrowing the buffer shape and adjust the access region if necessary. - * \example - * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. + * + * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. + * * \code * * for i in range(0, 16): diff --git a/include/tvm/topi/cuda/reduction.h b/include/tvm/topi/cuda/reduction.h index 7160419422a6..51f35ed8dc25 100644 --- a/include/tvm/topi/cuda/reduction.h +++ b/include/tvm/topi/cuda/reduction.h @@ -70,7 +70,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, if (out_stage->op.as()->axis.size() > 0) { all_reduce = false; num_thread = 32; - if (target->kind->name == "opencl") { + if (target->kind->name == "opencl" || target->kind->name == "metal") { // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests. // Don't know why. num_thread = 16; diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h new file mode 100644 index 000000000000..da76022c552b --- /dev/null +++ b/include/tvm/topi/detail/strided_slice.h @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file strided_slice.h + * \brief Utility functions for strided_slice op + */ +#ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_ +#define TVM_TOPI_DETAIL_STRIDED_SLICE_H_ + +#include + +#include +#include +#include +#include +#include + +#include "constant_utils.h" + +namespace tvm { +namespace topi { +namespace detail { + +using namespace tvm::te; + +inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) { + int64_t begin_range = stride < 0 ? -1 : 0; + int64_t end_range = stride < 0 ? extent - 1 : extent; + if (index < 0) { + index += extent; + } + return std::min(std::max(index, begin_range), end_range); +} + +inline std::tuple, std::vector, std::vector> ConvertToVec( + const Array& begin, const Array& end, const Array& strides, + std::string slice_mode) { + std::vector stride_vec(strides.size(), 1); + if (slice_mode == "end") { + for (size_t i = 0; i < strides.size(); ++i) { + ICHECK(strides[i].defined()); + stride_vec[i] = GetConstInt(strides[i]); + } + } + const int64_t max_range = std::numeric_limits::max(); + std::vector begin_vec; + for (size_t i = 0; i < begin.size(); ++i) { + if (!begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(GetConstInt(begin[i])); + } + } + std::vector end_vec; + for (size_t i = 0; i < end.size(); ++i) { + // allow end to be None + if (!end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + int64_t end_val = GetConstInt(end[i]); + if (end_val < 0) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(begin_vec[i] + end_val); + } + } else { + end_vec.push_back(GetConstInt(end[i])); + } + } + return std::make_tuple(begin_vec, end_vec, stride_vec); +} + +inline Array StridedSliceCanonicalizeBegin(const Array& ishape, + const std::vector& begin, + const std::vector& strides, + const Array& axes, DataType dtype, + std::string slice_mode = "end") { + Array begin_expr; + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + int64_t dim_i = GetConstInt(ishape[axes[i]]); + int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); + begin_expr.push_back(make_const(dtype, begin_i)); + } else { + auto idim = ishape[axes[i]]; + auto b_expr = make_const(dtype, begin[i]); + PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; + auto s = strides[i]; + if (s < 0) { + b = tvm::min(b, idim - 1); + } else { + b = tvm::if_then_else(b < 0, 0, b); + } + begin_expr.push_back(b); + } + } + return begin_expr; +} + +inline Array StridedSliceOutputShape(const Array& ishape, + const std::vector& begin, + const std::vector& end, + const std::vector& strides, + const Array& axes, std::string slice_mode, + const Array& begin_canonicalized, + bool use_any = false) { + const size_t src_tensor_dim = ishape.size(); + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(ishape[i]); + } + + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + const int64_t dim_i = GetConstInt(ishape[axes[i]]); + ICHECK(begin_canonicalized[i]->IsInstance()); + int64_t begin_i = GetConstInt(begin_canonicalized[i]); + int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); + int interval = std::abs(end_i - begin_i); + int slice_size = + static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); + ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) + << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; + out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); + } else if (use_any) { + out_shape.Set(axes[i], tvm::tir::Any()); + } else { + out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype)); + } + } + + return out_shape; +} + +} // namespace detail +} // namespace topi +} // namespace tvm +#endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_ diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 29c3156ab5d6..d3328c59afb4 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 36acc7376c7c..8d1a49a4cc5f 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -27,8 +27,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -39,8 +41,6 @@ #include #include -#include "detail/broadcast.h" - namespace tvm { namespace topi { @@ -551,7 +551,7 @@ inline Array split(const Tensor& x, Array split_indices, int a } /*! - * \brief strided_slice of a tensor with dynamic begin/end/stride + * \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic * * \param x The input tensor * \param begin The indices to begin with in the slicing @@ -561,31 +561,45 @@ inline Array split(const Tensor& x, Array split_indices, int a * \param name The name of the operation * \param tag The tag to mark the operation * - * \return A Tensor whose op member is the split operation + * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, - const te::Tensor& end, const te::Tensor& strides, - std::string name = "T_strided_slice_dynamic", - std::string tag = topi::kInjective) { - int64_t src_tensor_dim = x->shape.size(); +inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + std::string name = "T_dynamic_strided_slice", + std::string tag = kInjective) { + const size_t src_tensor_dim = x->shape.size(); + ICHECK_LE(begin.size(), src_tensor_dim); + ICHECK_LE(end.size(), src_tensor_dim); + ICHECK_LE(strides.size(), src_tensor_dim); + ICHECK_EQ(begin.size(), end.size()); + ICHECK_EQ(begin.size(), strides.size()); + + const size_t num_slice_axes = begin.size(); Array out_shape; - const int64_t num_dynamic_axes = begin->shape[0].as()->value; - for (int64_t i = 0; i < num_dynamic_axes; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); + + for (size_t i = 0; i < num_slice_axes; ++i) { + auto d = indexdiv(end[i] - begin[i], strides[i]); + if (d->IsInstance()) { + // Preserve static dimension if possible + out_shape.push_back(d); + } else { + out_shape.push_back(tvm::tir::Var("dim")); + } } - for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + + for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) { out_shape.push_back(x->shape[i]); } + return te::compute( out_shape, [&](const Array& indices) { Array real_indices; - // dynamic slicing - for (int32_t i = 0; i < num_dynamic_axes; ++i) { - real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); + for (size_t i = 0; i < num_slice_axes; ++i) { + real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); } // keep input dim - for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) { real_indices.push_back(indices[i]); } return x(real_indices); @@ -594,137 +608,152 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b } /*! - * \brief strided_slice of a tensor + * \brief strided_slice of a tensor with dynamic begin/end/stride * * \param x The input tensor * \param begin The indices to begin with in the slicing * \param end Indicies indicating end of the slice * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis - * \param slice_mode Specifies the slice mode * \param name The name of the operation * \param tag The tag to mark the operation * - * \return A Tensor whose op member is the split operation + * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", std::string name = "T_strided_slice", - std::string tag = kInjective) { - size_t src_tensor_dim = static_cast(x->shape.size()); - // Quick path for dynamic shape strided slice. - // This is for ease of use to dynamice strided slice in topi. - bool is_static = IsConstIntArray(x->shape); - is_static &= IsConstIntArray(begin); - is_static &= IsConstIntArray(end); - is_static &= IsConstIntArray(strides); - - Array out_shape; - if (!is_static) { - ICHECK_EQ(strides.size(), src_tensor_dim); - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); - } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides[i] + begin[i]); - } - return x(real_indices); - }, - name, tag); - } - - // Setup the ranges. - // NOTE: this code duplicates the shape inference logic relay.op - // Consider to refactor in the future. - std::vector stride_vec(src_tensor_dim, 1); - for (size_t i = 0; i < strides.size(); ++i) { - ICHECK(strides[i].defined()); - stride_vec[i] = GetConstInt(strides[i]); - } - - const int64_t max_range = std::numeric_limits::max(); +inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, + const te::Tensor& end, const te::Tensor& strides, + std::string name = "T_strided_slice_dynamic", + std::string tag = topi::kInjective) { + const int64_t num_dynamic_axes = begin->shape[0].as()->value; + ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); + ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); - std::vector begin_vec; - for (size_t i = 0; i < begin.size(); ++i) { - if (!begin[i].defined()) { - // value=None - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(GetConstInt(begin[i])); - } - } - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + Array begin_expr, end_expr, strides_expr; + for (int64_t i = 0; i < num_dynamic_axes; ++i) { + auto i64_ind = IntImm(DataType::Int(64), i); + begin_expr.push_back(begin(i64_ind)); + end_expr.push_back(end(i64_ind)); + strides_expr.push_back(strides(i64_ind)); } + return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); +} - std::vector end_vec; - for (size_t i = 0; i < end.size(); ++i) { - // allow end to be None +/*! + * \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation + * + * \param ishape The input tensor shape + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end, + * strides, and axes argument must be equal + * \param slice_mode Specifies the slice mode + * + * \return The output shape of strided_slice using the arguments above + */ +inline Array StridedSliceOutputShape( + const Array& ishape, const Array& begin, const Array& end, + const Array& strides, const Array& axes, const std::string& slice_mode) { + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + std::vector begin_vec, end_vec, strides_vec; + std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); + auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, + begin[0]->dtype, slice_mode); + return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, + begin_canonicalized, true); +} - if (!end[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else if (slice_mode == "size") { - int64_t end_val = GetConstInt(end[i]); - if (end_val < 0) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else { - end_vec.push_back(begin_vec[i] + end_val); - } - } else { - end_vec.push_back(GetConstInt(end[i])); - } - } - for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } - // Compute - Array begin_expr; - Array strides_expr; - - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; - int64_t dim_i = GetConstInt(x->shape[i]); - int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i; - // transform negative indices to positive value, clips on the correct range - auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) { - if (index < 0) { - index += dim_i; - } - return std::min(std::max(index, begin_range), end_range); - }; - - int64_t begin_i = index_canonicalization(begin_vec[i]); - int64_t end_i = index_canonicalization(end_vec[i]); - - int interval = std::abs(end_i - begin_i); - int slice_size = - static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); - ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; - - begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); - out_shape.push_back(slice_size); - } +/*! + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end, + * strides, and axes argument must be equal + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the sstrided_slice operation + */ +inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + const Array& axes, std::string slice_mode = "end", + std::string name = "T_strided_slice_with_axes", + std::string tag = kInjective) { + const size_t src_tensor_dim = x->shape.size(); + ICHECK(axes.size() <= src_tensor_dim); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + + std::vector begin_vec, end_vec, strides_vec; + std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); + + auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes, + begin[0]->dtype, slice_mode); + auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes, + slice_mode, begin_expr); - return compute( + return te::compute( out_shape, - [&](const Array& indices) { + [&](const Array& indices) { Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + auto stride = make_const(strides[i].dtype(), strides_vec[i]); + PrimExpr ind = indices[axes[i]] * stride + begin_expr[i]; + real_indices.Set(axes[i], ind); } return x(real_indices); }, name, tag); } +/*! + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the strided_slice operation + */ +inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, + const Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { + size_t src_tensor_dim = static_cast(x->shape.size()); + Array axes; + for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); + Array begin_full(begin); + Array end_full(end); + Array strides_full(strides); + + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides_full.push_back(one); + } + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range); + } + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range); + } + + return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, + tag); +} + /*! * \brief Split a tensor into a number of sub-tensors * diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 6869e4829d98..00dd4643190e 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -195,7 +195,7 @@ ("astroid", None), ("attrs", None), ("autodocsumm", None), - ("black", None), + ("black", "==20.8b1"), ("cloudpickle", None), ("commonmark", ">=0.7.3"), # From PR #213. ("coremltools", None), diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 0adad82d9bec..55a228882691 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -31,7 +31,7 @@ # tvm.runtime from .runtime.object import Object from .runtime.ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl -from .runtime.ndarray import vpi, rocm, ext_dev, micro_dev, hexagon +from .runtime.ndarray import vpi, rocm, ext_dev, hexagon from .runtime import ndarray as nd # tvm.error @@ -40,6 +40,7 @@ # tvm.ir from .ir import IRModule from .ir import transform +from .ir import instrument from .ir import container from . import ir @@ -67,7 +68,6 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel - # NOTE: This file should be python2 compatible so we can # raise proper error message when user run the package using # an older version of the python diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6cfa3e5c286a..bf763a194311 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -73,7 +73,7 @@ def convert_to_tvm_func(pyfunc): local_pyfunc = pyfunc def cfun(args, type_codes, num_args, ret, _): - """ ctypes function """ + """ctypes function""" num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)) # pylint: disable=broad-except diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 4eda5e8cc332..450a356aebdf 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -173,7 +173,6 @@ class Device(ctypes.Structure): 9: "vpi", 10: "rocm", 12: "ext_dev", - 13: "micro_dev", 14: "hexagon", 15: "webgpu", } @@ -194,7 +193,6 @@ class Device(ctypes.Structure): "vpi": 9, "rocm": 10, "ext_dev": 12, - "micro_dev": 13, "hexagon": 14, "webgpu": 15, } @@ -376,7 +374,7 @@ def api_version(self): The version of the SDK """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.device_type, self.device_id, 11) @property def driver_version(self): diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index c3b32b5960eb..5c532c692b1d 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -22,7 +22,7 @@ @tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): - """Represent range of (coeff * x + base) for x in Z """ + """Represent range of (coeff * x + base) for x in Z""" def __init__(self, coeff, base): self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 7cfe6ccbc2c0..03cc00def6b7 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -48,12 +48,12 @@ @tvm._ffi.register_object("auto_scheduler.Iterator") class Iterator(Object): - """ A loop iterator structure. """ + """A loop iterator structure.""" @tvm._ffi.register_object("auto_scheduler.Stage") class Stage(Object): - """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """ + """A stage in the compute declaration. Similar to tvm.te.schedule.Stage.""" # Static trans table for compute_at location # This is used to transform the compute_at location to C++ enum @@ -62,7 +62,7 @@ class Stage(Object): @tvm._ffi.register_object("auto_scheduler.State") class StateObject(Object): - """ The internal State object """ + """The internal State object""" def __eq__(self, other): return _ffi_api.StateEqual(self, other) @@ -579,7 +579,7 @@ def rfactor(self, stage, iterator, factor_iter_id): return self.stages[int(new_stage_id)].op def copy(self): - """ Do deep copy of this State. """ + """Do deep copy of this State.""" state = State(self.state_object, self.compute_dag) state.stage_id_map = self.stage_id_map.copy() return state diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index ea4a129727c3..8d762602bfd1 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -84,7 +84,7 @@ class BuildFunc: @tvm._ffi.register_object("auto_scheduler.MeasureCallback") class MeasureCallback(Object): - """ The base class of measurement callback functions. """ + """The base class of measurement callback functions.""" @tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback") @@ -244,7 +244,7 @@ def recover_measure_input(inp, rebuild_state=False): @tvm._ffi.register_object("auto_scheduler.ProgramBuilder") class ProgramBuilder(Object): - """ The base class of ProgramBuilders. """ + """The base class of ProgramBuilders.""" def build(self, measure_inputs, verbose=1): """Build programs and return results. @@ -265,7 +265,7 @@ def build(self, measure_inputs, verbose=1): @tvm._ffi.register_object("auto_scheduler.ProgramRunner") class ProgramRunner(Object): - """ The base class of ProgramRunners. """ + """The base class of ProgramRunners.""" def run(self, measure_inputs, build_results, verbose=1): """Run measurement and return results. @@ -585,7 +585,7 @@ def __del__(self): class MeasureErrorNo(object): - """ Error type for MeasureResult. """ + """Error type for MeasureResult.""" NO_ERROR = 0 # No error INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 71acd03928a1..099502d17d78 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -24,6 +24,7 @@ import logging import threading +from copy import deepcopy import tvm from tvm import autotvm, transform @@ -64,7 +65,10 @@ def call_all_topi_funcs(mod, params, target): disabled_pass={"AutoSchedulerLayoutRewrite"}, ): try: - opt_mod, _ = relay.optimize(mod, target, params) + # TODO(jwfromm) Remove this once AlterOpLayout bug that mutates + # source module is fixed. Until then, create a clone. + mod_clone = deepcopy(mod) + opt_mod, _ = relay.optimize(mod_clone, target, params) grc = graph_executor_codegen.GraphExecutorCodegen(None, target) grc.codegen(opt_mod["main"]) except tvm.TVMError: @@ -72,11 +76,16 @@ def call_all_topi_funcs(mod, params, target): "Get errors with GraphExecutorCodegen for task extraction. " "Fallback to VMCompiler." ) + mod_clone = deepcopy(mod) compiler = relay.vm.VMCompiler() if params: compiler.set_params(params) - mod = tvm.IRModule.from_expr(mod) if isinstance(mod, relay.Function) else mod - compiler.lower(mod, target) + mod_clone = ( + tvm.IRModule.from_expr(mod_clone) + if isinstance(mod_clone, relay.Function) + else mod_clone + ) + compiler.lower(mod_clone, target) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index f0388a886c5f..a88c1305b560 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -96,7 +96,7 @@ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule" @tvm._ffi.register_object("auto_scheduler.SearchPolicy") class SearchPolicy(Object): - """ The base class of search policies. """ + """The base class of search policies.""" def continue_search_one_round(self, num_measure, measurer): """ diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 5cae556e2747..dd5073331083 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -540,7 +540,7 @@ def _restore_status(self, log_file, num_measures_per_round): class TaskSchedulerCallback: - """The base class of task scheduler callback functions. """ + """The base class of task scheduler callback functions.""" def pre_tune(self, task_scheduler, task_id): """The callback before tuning each task. diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 14dc5b8984c3..1c03491c5614 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -190,7 +190,7 @@ def get_const_tuple(in_tuple): def list_to_tuple(x): - """ Convert a list to a tuple recursively. """ + """Convert a list to a tuple recursively.""" assert isinstance(x, list) return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x) @@ -250,7 +250,7 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM): def make_traceback_info(): - """ Get the error message from traceback. """ + """Get the error message from traceback.""" info = str(traceback.format_exc()) if len(info) > MAX_TRACEBACK_INFO_LEN: info = ( diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index dff0f098d84a..8d2591dce50b 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, binds) + binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 bounds = schedule.InferBound(sch) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 60a26ecd7d81..f41795fb0810 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -276,8 +276,6 @@ def get_build_kwargs(self): if "cuda" in self.task.target.keys: kwargs["cuda_arch"] = "sm_" + "".join(dev.compute_version.split(".")) - if self.task.target.device_name == "micro_dev": - kwargs.setdefault("build_option", {})["tir.disable_vectorize"] = True return kwargs diff --git a/python/tvm/autotvm/task/code_hash.py b/python/tvm/autotvm/task/code_hash.py index 3331fc13c719..2bd053da7244 100644 --- a/python/tvm/autotvm/task/code_hash.py +++ b/python/tvm/autotvm/task/code_hash.py @@ -19,6 +19,7 @@ code hashing is used to check the consistence of schedule code and the parameters loaded from log """ +import functools import inspect import zlib @@ -35,6 +36,7 @@ def attach_code_hash(s): """ def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) raw_hash = zlib.crc32("".join(inspect.getsourcelines(func)[0]).encode()) @@ -56,6 +58,7 @@ def attach_code_hash_to_arg(arg_idx=1): """ def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) assert isinstance(args[arg_idx], schedule.Schedule) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 9117ce398d49..3dceac1b7ffd 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -22,6 +22,7 @@ """ import threading import logging +from copy import deepcopy import tvm from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext @@ -53,7 +54,10 @@ def _lower(mod, target, params): # If failed to compile, then fallback to use VM compiler. # TODO: Currently VM compiler is likely to stack overflow for large models. try: - opt_mod, _ = relay.optimize(mod, target, params) + # TODO(jwfromm) Remove this once AlterOpLayout bug that mutates + # source module is fixed. Until then, create a clone. + mod_clone = deepcopy(mod) + opt_mod, _ = relay.optimize(mod_clone, target, params) grc = graph_executor_codegen.GraphExecutorCodegen(None, target) grc.codegen(opt_mod["main"]) except tvm.TVMError as e: @@ -61,10 +65,11 @@ def _lower(mod, target, params): "Get errors with GraphExecutorCodegen for task extraction. " "Fallback to VMCompiler. Error details:\n%s" % str(e) ) + mod_clone = deepcopy(mod) compiler = relay.vm.VMCompiler() if params: compiler.set_params(params) - compiler.lower(mod, target=target) + compiler.lower(mod_clone, target=target) def extract_from_program(mod, params, target, target_host=None, ops=None): diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 668832b8a86c..1f5827d7e9d0 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -21,6 +21,8 @@ func is a state-less function, or a string that registers the standard task. """ +import functools + import numpy as np from tvm import runtime @@ -411,6 +413,7 @@ def matmul(N, L, M, dtype): """ def _decorate(f): + @functools.wraps(f) def wrapper(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" workload = args_to_workload(args, task_name) diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 2558c7669ac9..32d8674640ed 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -26,6 +26,8 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ +import functools + import tvm.te._ffi_api from tvm.target import Target from tvm.te import tensor @@ -149,6 +151,7 @@ def register_topi_compute(task_name, func=None): """ def _decorate(topi_compute): + @functools.wraps(topi_compute) @_register_task_compute(task_name) def wrapper(*args, **kwargs): """wrapper function for topi compute""" @@ -224,6 +227,7 @@ def register_topi_schedule(task_name, func=None): """ def _decorate(topi_schedule): + @functools.wraps(topi_schedule) @_register_task_schedule(task_name) def wrapper(outs, *args, **kwargs): """wrapper function for topi schedule""" diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index f346b33bc34d..5c0a46336532 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -65,7 +65,6 @@ def _alias(name): """convert alias for some packages""" table = { "vtacpu": "vta", - "metal": "opencl", "webgpu": "opencl", "vulkan": "opencl", "nvptx": "cuda", @@ -201,6 +200,8 @@ def load_reference_log(backend, model, workload_name): """ backend = _alias(backend) + if backend not in PACKAGE_VERSION: + return [] version = PACKAGE_VERSION[backend] package_name = "%s_%s.log" % (backend, version) filename = Path(AUTOTVM_TOPHUB_ROOT_PATH, package_name) diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index 58bf933d44b8..1dfeb801b370 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -72,7 +72,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs C: Tensor The result tensor. """ - b = lhs.shape[0] + b = te.max(lhs.shape[0], rhs.shape[0]) n = lhs.shape[2] if transa else lhs.shape[1] m = rhs.shape[1] if transb else rhs.shape[2] return te.extern( diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index f48ae395fbcd..64cbbd28604c 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -92,7 +92,7 @@ def get_target_by_dump_machine(compiler): """ def get_target_triple(): - """ Get target triple according to dumpmachine option of compiler.""" + """Get target triple according to dumpmachine option of compiler.""" if compiler: cmd = [compiler, "-dumpmachine"] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/python/tvm/contrib/mkl.py b/python/tvm/contrib/mkl.py index c6e340619ef8..449d660c9027 100644 --- a/python/tvm/contrib/mkl.py +++ b/python/tvm/contrib/mkl.py @@ -105,7 +105,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs C: Tensor The result tensor. """ - b = lhs.shape[0] + b = te.max(lhs.shape[0], rhs.shape[0]) n = lhs.shape[2] if transa else lhs.shape[1] m = rhs.shape[1] if transb else rhs.shape[2] return te.extern( diff --git a/python/tvm/contrib/peak.py b/python/tvm/contrib/peak.py index 195f3dc9d81e..4133aa31a50b 100644 --- a/python/tvm/contrib/peak.py +++ b/python/tvm/contrib/peak.py @@ -26,7 +26,7 @@ def _convert_to_remote(func, remote): - """ convert module function to remote rpc function""" + """convert module function to remote rpc function""" temp = utils.tempdir() path_dso = temp.relpath("tmp_func.tar") func.export_library(path_dso) diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py index 10598e26824e..a65f5e474a3d 100644 --- a/python/tvm/contrib/tedd.py +++ b/python/tvm/contrib/tedd.py @@ -147,7 +147,7 @@ def get_itervar_label_color(itervar, iv_type): def linebrk(s, n): - """ Break input string s with
for every n charactors.""" + """Break input string s with
for every n charactors.""" result = "" j = 0 for i, c in enumerate(s): diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py new file mode 100644 index 000000000000..c423656d78f5 --- /dev/null +++ b/python/tvm/driver/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.driver""" +import tvm._ffi + +tvm._ffi._init_api("driver", __name__) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a3d0bb656736..a4df63f225b2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -37,96 +37,58 @@ from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from . import _ffi_api as ffi + def get_binds(args, compact=False, binds=None): """Internal function to get binds and arg_list given arguments. - Parameters ---------- args : list of Buffer or Tensor or Var The argument lists to the function. - compact : bool If the statement has already bound to a compact buffer. - binds : dict of :any:`Tensor` to :any:`Buffer`, optional Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. - Returns ------- binds: dict The bind specification - arg_list: list The list of symbolic buffers of arguments. """ - binds = {} if binds is None else binds.copy() - arg_list = [] - for x in args: - if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape) - buffer_type = "auto_broadcast" if any_dim and not compact else "" - if x not in binds: - buf = tvm.tir.decl_buffer( - x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type - ) - binds[x] = buf - arg_list.append(buf) - else: - arg_list.append(binds[x]) - elif isinstance(x, schedule.Buffer): - arg_list.append(x) - elif isinstance(x, tvm.tir.Var): - arg_list.append(x) - else: - raise ValueError("args must be Tensor, Buffer or Var") + binds, arg_list = ffi.get_binds(args, compact, binds) return binds, arg_list -def form_irmodule(sch, args, name, binds): +def schedule_to_module( + sch: schedule.Schedule, + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: """According to the given schedule, form a function. - Parameters ---------- sch : tvm.te.schedule.Schedule The given scheduler to form the raw body - args : list of Buffer or Tensor or Var The argument lists to the function. - name : str - The name of result function. - + The name of result function, default name is "main" binds : dict of :any:`Tensor` to :any:`Buffer`, optional The binds information - Returns ------- The body formed according to the given schedule """ - # normalize schedule first - pass_ctx = PassContext.current() - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - - compact = schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) - - stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) - func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - return tvm.IRModule({name: func}) + return ffi.schedule_to_module(sch, args, name, binds) def lower( - inputs: Union[schedule.Schedule, PrimFunc, IRModule], + inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -136,7 +98,7 @@ def lower( Parameters ---------- - input : Union[schedule.Schedule, PrimFunc, IRModule] + inputs : Union[schedule.Schedule, PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] @@ -160,90 +122,13 @@ def lower( m : IRModule The result IRModule """ - # config setup - pass_ctx = PassContext.current() - instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) - disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) - add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) - - lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] - lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] - lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] - lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - - # Phase 0 - pass_list = lower_phase0 - is_legacy_te_schedule: bool = False - - if isinstance(inputs, schedule.Schedule): - if args is None: - raise ValueError("args must be given for lowering from TE schedule") - mod = form_irmodule(inputs, args, name, binds) - is_legacy_te_schedule = True - elif isinstance(inputs, PrimFunc): - func = inputs.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) - elif isinstance(inputs, IRModule): - mod = inputs - else: - raise TypeError( - f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}" - ) - - # Phase 1 - if is_legacy_te_schedule: - pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), - ] - else: - pass_list += [ - tvm.tir.transform.LowerInitBlock(), - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.CompactBufferAllocation(), - tvm.tir.transform.FlattenBuffer(), - ] - pass_list += [ - tvm.tir.transform.BF16Legalize(), - tvm.tir.transform.NarrowDataType(32), - tvm.tir.transform.Simplify(), - ] - - pass_list += lower_phase1 - - # Phase 2 - if not simple_mode: - pass_list += [(tvm.tir.transform.LoopPartition())] - - pass_list += [ - tvm.tir.transform.VectorizeLoop(not disable_vectorize), - tvm.tir.transform.InjectVirtualThread(), - tvm.tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.StorageRewrite(), - tvm.tir.transform.UnrollLoop(), - ] - pass_list += lower_phase2 - - # Phase 3 - pass_list += [ - tvm.tir.transform.Simplify(), - tvm.tir.transform.RemoveNoOp(), - ] - - pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - pass_list += [tvm.tir.transform.HoistIfThenElse()] - pass_list += lower_phase3 - - # Instrument BoundCheckers - if instrument_bound_checkers: - pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - - optimize = tvm.transform.Sequential(pass_list) - mod = optimize(mod) - return mod + if isinstance(inp, IRModule): + return ffi.lower_module(inp, simple_mode) + if isinstance(inp, PrimFunc): + return ffi.lower_primfunc(inp, name, simple_mode) + if isinstance(inp, schedule.Schedule): + return ffi.lower_schedule(inp, args, name, binds, simple_mode) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def _build_for_device(input_mod, target, target_host): diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index b92ab86ef621..e5e59b2bbde2 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -46,7 +46,7 @@ @register_parser def add_tune_parser(subparsers): - """ Include parser for 'tune' subcommand """ + """Include parser for 'tune' subcommand""" parser = subparsers.add_parser("tune", help="auto-tune a model") parser.set_defaults(func=drive_tune) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 34f59aac9712..48e18fb6b6ad 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -97,10 +97,10 @@ def validate_targets(parse_targets): ) tvm_targets = [t for t in targets if t in tvm_target_kinds] - if len(tvm_targets) > 1: + if len(tvm_targets) > 2: verbose_tvm_targets = ", ".join(tvm_targets) raise TVMCException( - "Only one of the following targets can be used at a time. " + "Only two of the following targets can be used at a time. " f"Found: {verbose_tvm_targets}." ) @@ -199,6 +199,7 @@ def parse_target(target): """ codegens = [] + tvm_target_kinds = tvm.target.Target.list_kinds() parsed_tokens = tokenize_target(target) split_codegens = [] @@ -222,6 +223,7 @@ def parse_target(target): for codegen_def in split_codegens: # the first is expected to be the name name = codegen_def[0] + is_tvm_target = name in tvm_target_kinds raw_target = " ".join(codegen_def) all_opts = codegen_def[1:] if len(codegen_def) > 1 else [] opts = {} @@ -244,7 +246,9 @@ def parse_target(target): opts[opt_name] = opt_value - codegens.append({"name": name, "opts": opts, "raw": raw_target}) + codegens.append( + {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target": is_tvm_target} + ) return codegens @@ -295,10 +299,21 @@ def target_from_cli(target): raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {ex}") validate_targets(parsed_targets) - target = parsed_targets[-1]["raw"] - extra_targets = parsed_targets[:-1] if len(parsed_targets) > 1 else [] + tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]] + + # Validated target strings have 1 or 2 tvm targets, otherwise + # `validate_targets` above will fail. + if len(tvm_targets) == 1: + target = tvm_targets[0]["raw"] + target_host = None + else: + assert len(tvm_targets) == 2 + target = tvm_targets[0]["raw"] + target_host = tvm_targets[1]["raw"] + + extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] - return tvm.target.Target(target), extra_targets + return tvm.target.Target(target, host=target_host), extra_targets def tracker_host_port_from_cli(rpc_tracker_str): diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index dcb770b9a563..071474a31594 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,42 +38,50 @@ @register_parser def add_compile_parser(subparsers): - """ Include parser for 'compile' subcommand """ + """Include parser for 'compile' subcommand""" - parser = subparsers.add_parser("compile", help="compile a model") + parser = subparsers.add_parser("compile", help="compile a model.") parser.set_defaults(func=drive_compile) parser.add_argument( "--cross-compiler", default="", - help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'", + help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'.", ) parser.add_argument( "--cross-compiler-options", default="", - help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'", + help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'.", ) parser.add_argument( "--desired-layout", choices=["NCHW", "NHWC"], default=None, - help="change the data layout of the whole graph", + help="change the data layout of the whole graph.", ) parser.add_argument( "--dump-code", metavar="FORMAT", default="", - help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ", + help="comma separated list of formats to export the input model, e.g. 'asm,ll,relay'.", ) parser.add_argument( "--model-format", choices=frontends.get_frontend_names(), - help="specify input model format", + help="specify input model format.", ) parser.add_argument( "-o", "--output", default="module.tar", - help="output the compiled module to an archive", + help="output the compiled module to a specifed archive. Defaults to 'module.tar'.", + ) + parser.add_argument( + "-f", + "--output-format", + choices=["so", "mlf"], + default="so", + help="output format. Use 'so' for shared object or 'mlf' for Model Library Format " + "(only for µTVM targets). Defaults to 'so'.", ) parser.add_argument( "--target", @@ -85,23 +93,23 @@ def add_compile_parser(subparsers): metavar="PATH", default="", help="path to an auto-tuning log file by AutoTVM. If not presented, " - "the fallback/tophub configs will be used", + "the fallback/tophub configs will be used.", ) - parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") + parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity.") # TODO (@leandron) This is a path to a physical file, but # can be improved in future to add integration with a modelzoo # or URL, for example. - parser.add_argument("FILE", help="path to the input model file") + parser.add_argument("FILE", help="path to the input model file.") parser.add_argument( "--input-shapes", help="specify non-generic shapes for model to run, format is " - '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]".', type=common.parse_shape_string, default=None, ) parser.add_argument( "--disabled-pass", - help="disable specific passes, comma-separated list of pass names", + help="disable specific passes, comma-separated list of pass names.", type=common.parse_pass_list_str, default="", ) @@ -132,6 +140,7 @@ def drive_compile(args): package_path=args.output, cross=args.cross_compiler, cross_options=args.cross_compiler_options, + output_format=args.output_format, dump_code=dump_code, target_host=None, desired_layout=args.desired_layout, @@ -148,7 +157,7 @@ def compile_model( package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, cross_options: Optional[str] = None, - export_format: str = "so", + output_format: str = "so", dump_code: Optional[List[str]] = None, target_host: Optional[str] = None, desired_layout: Optional[str] = None, @@ -177,7 +186,7 @@ def compile_model( Function that performs the actual compilation cross_options : str, optional Command line options to be passed to the cross compiler. - export_format : str + output_format : str What format to use when saving the function library. Must be one of "so" or "tar". When compiling for a remote device without a cross compiler, "tar" will likely work better. dump_code : list, optional @@ -262,7 +271,11 @@ def compile_model( # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( - graph_module, package_path, cross, cross_options, export_format + graph_module, + package_path, + cross, + cross_options, + output_format, ) # Write dumps to file. diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 89ca1b8fc329..ceee5ccd7266 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -77,7 +77,7 @@ def load(self, path, shape_dict=None, **kwargs): def import_keras(): - """ Lazy import function for Keras""" + """Lazy import function for Keras""" # Keras writes the message "Using TensorFlow backend." to stderr # Redirect stderr during the import to disable this stderr = sys.stderr @@ -93,7 +93,7 @@ def import_keras(): class KerasFrontend(Frontend): - """ Keras frontend for TVMC """ + """Keras frontend for TVMC""" @staticmethod def name(): @@ -151,7 +151,7 @@ def sequential_to_functional(self, model): class OnnxFrontend(Frontend): - """ ONNX frontend for TVMC """ + """ONNX frontend for TVMC""" @staticmethod def name(): @@ -172,7 +172,7 @@ def load(self, path, shape_dict=None, **kwargs): class TensorflowFrontend(Frontend): - """ TensorFlow frontend for TVMC """ + """TensorFlow frontend for TVMC""" @staticmethod def name(): @@ -199,7 +199,7 @@ def load(self, path, shape_dict=None, **kwargs): class TFLiteFrontend(Frontend): - """ TFLite frontend for TVMC """ + """TFLite frontend for TVMC""" @staticmethod def name(): @@ -237,7 +237,7 @@ def load(self, path, shape_dict=None, **kwargs): class PyTorchFrontend(Frontend): - """ PyTorch frontend for TVMC """ + """PyTorch frontend for TVMC""" @staticmethod def name(): diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index 1d360d98206e..2574daab02ac 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -53,7 +53,7 @@ def _example_parser(main_subparser): def _main(argv): - """ TVM command line interface. """ + """TVM command line interface.""" parser = argparse.ArgumentParser( prog="tvmc", diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 26a1e3600b96..c0ebb842b994 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -54,6 +54,11 @@ from tvm.contrib import utils from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule +try: + from tvm.micro import export_model_library_format +except ImportError: + export_model_library_format = None + from .common import TVMCException @@ -175,7 +180,7 @@ def default_package_path(self): """ return self._tmp_dir.relpath("model_package.tar") - def export_package( + def export_classic_format( self, executor_factory: GraphExecutorFactoryModule, package_path: Optional[str] = None, @@ -203,8 +208,6 @@ def export_package( package_path : str The path that the package was saved to. """ - if lib_format not in ["so", "tar"]: - raise TVMCException("Only .so and .tar export formats are supported.") lib_name = "mod." + lib_format graph_name = "mod.json" param_name = "mod.params" @@ -241,6 +244,53 @@ def export_package( return package_path + def export_package( + self, + executor_factory: GraphExecutorFactoryModule, + package_path: Optional[str] = None, + cross: Optional[Union[str, Callable]] = None, + cross_options: Optional[str] = None, + output_format: str = "so", + ): + """Save this TVMCModel to file. + Parameters + ---------- + executor_factory : GraphExecutorFactoryModule + The factory containing compiled the compiled artifacts needed to run this model. + package_path : str, None + Where the model should be saved. Note that it will be packaged as a .tar file. + If not provided, the package will be saved to a generically named file in tmp. + cross : str or callable object, optional + Function that performs the actual compilation. + cross_options : str, optional + Command line options to be passed to the cross compiler. + output_format : str + How to save the modules function library. Must be one of "so" and "tar" to save + using the classic format or "mlf" to save using the Model Library Format. + + Returns + ------- + package_path : str + The path that the package was saved to. + """ + if output_format not in ["so", "tar", "mlf"]: + raise TVMCException("Only 'so', 'tar', and 'mlf' output formats are supported.") + + if output_format == "mlf" and cross: + raise TVMCException("Specifying the MLF output and a cross compiler is not supported.") + + if output_format in ["so", "tar"]: + package_path = self.export_classic_format( + executor_factory, package_path, cross, cross_options, output_format + ) + elif output_format == "mlf": + if export_model_library_format: + package_path = export_model_library_format(executor_factory, package_path) + else: + raise Exception("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") + + return package_path + def summary(self, file: TextIO = None): """Print the IR corressponding to this model. @@ -274,25 +324,41 @@ def import_package(self, package_path: str): package_path : str The path to the saved TVMCPackage. """ - lib_name_so = "mod.so" - lib_name_tar = "mod.tar" - graph_name = "mod.json" - param_name = "mod.params" - temp = self._tmp_dir t = tarfile.open(package_path) t.extractall(temp.relpath(".")) - with open(temp.relpath(param_name), "rb") as param_file: - self.params = bytearray(param_file.read()) - self.graph = open(temp.relpath(graph_name)).read() - if os.path.exists(temp.relpath(lib_name_so)): - self.lib_name = lib_name_so - elif os.path.exists(temp.relpath(lib_name_tar)): - self.lib_name = lib_name_tar + if os.path.exists(temp.relpath("metadata.json")): + # Model Library Format (MLF) + self.lib_name = None + self.lib_path = None + + graph = temp.relpath("runtime-config/graph/graph.json") + params = temp.relpath("parameters/default.params") + + self.type = "mlf" else: - raise TVMCException("Couldn't find exported library in the package.") - self.lib_path = temp.relpath(self.lib_name) + # Classic format + lib_name_so = "mod.so" + lib_name_tar = "mod.tar" + if os.path.exists(temp.relpath(lib_name_so)): + self.lib_name = lib_name_so + elif os.path.exists(temp.relpath(lib_name_tar)): + self.lib_name = lib_name_tar + else: + raise TVMCException("Couldn't find exported library in the package.") + self.lib_path = temp.relpath(self.lib_name) + + graph = temp.relpath("mod.json") + params = temp.relpath("mod.params") + + self.type = "classic" + + with open(params, "rb") as param_file: + self.params = bytearray(param_file.read()) + + with open(graph) as graph_file: + self.graph = graph_file.read() class TVMCResult(object): diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index c59689face63..191f8616c405 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -51,7 +51,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "gpu", "cl"], + choices=["cpu", "cuda", "cl"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -323,7 +323,7 @@ def run_module( tvmc_package: TVMCPackage The compiled model package object that will be run. device: str, - the device (e.g. "cpu" or "gpu") to be targeted by the RPC + the device (e.g. "cpu" or "cuda") to be targeted by the RPC session, local or remote). hostname : str, optional The hostname of the target device on which to run. @@ -359,6 +359,14 @@ def run_module( "Try calling tvmc.compile on the model before running it." ) + # Currently only two package formats are supported: "classic" and + # "mlf". The later can only be used for micro targets, i.e. with µTVM. + if tvmc_package.type == "mlf": + raise TVMCException( + "You're trying to run a model saved using the Model Library Format (MLF)." + "MLF can only be used to run micro targets (µTVM)." + ) + if hostname: if isinstance(port, str): port = int(port) diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4bc7f1ae4468..b4cc4421b169 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,7 +23,7 @@ from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .op import Op, register_op, register_op_attr, register_intrin_lowering +from .op import Op, register_op_attr, register_intrin_lowering from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule @@ -31,4 +31,5 @@ from .container import Array, Map from . import transform +from . import instrument from . import diagnostics diff --git a/python/tvm/ir/_ffi_instrument_api.py b/python/tvm/ir/_ffi_instrument_api.py new file mode 100644 index 000000000000..bf62caf30e5a --- /dev/null +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.instrument""" +import tvm._ffi + +tvm._ffi._init_api("instrument", __name__) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py new file mode 100644 index 000000000000..c322f2bef3fc --- /dev/null +++ b/python/tvm/ir/instrument.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Common pass instrumentation across IR variants.""" +import inspect +import functools + +import tvm._ffi +import tvm.runtime + +from . import _ffi_instrument_api + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassInstrument(tvm.runtime.Object): + """A pass instrument implementation. + + Users don't need to interact with this class directly. + Instead, a `PassInstrument` instance should be created through `pass_instrument`. + + See Also + -------- + `pass_instrument` + """ + + +def _wrap_class_pass_instrument(pi_cls): + """Wrap a python class as pass instrument""" + + class PyPassInstrument(PassInstrument): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pi_cls creation failed. + self.handle = None + inst = pi_cls(*args, **kwargs) + + # check method declartion within class, if found, wrap it. + def create_method(method): + if hasattr(inst, method) and inspect.ismethod(getattr(inst, method)): + + def func(*args): + return getattr(inst, method)(*args) + + func.__name__ = "_" + method + return func + return None + + # create runtime pass instrument object + # reister instance's enter_pass_ctx,exit_pass_ctx, should_run, run_before_pass and + # run_after_pass methods to it if present. + self.__init_handle_by_constructor__( + _ffi_instrument_api.PassInstrument, + pi_cls.__name__, + create_method("enter_pass_ctx"), + create_method("exit_pass_ctx"), + create_method("should_run"), + create_method("run_before_pass"), + create_method("run_after_pass"), + ) + + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyPassInstrument.__init__, pi_cls.__init__) + PyPassInstrument.__name__ = pi_cls.__name__ + PyPassInstrument.__doc__ = pi_cls.__doc__ + PyPassInstrument.__module__ = pi_cls.__module__ + return PyPassInstrument + + +def pass_instrument(pi_cls=None): + """Decorate a pass instrument. + + Parameters + ---------- + pi_class : + + Examples + -------- + The following code block decorates a pass instrument class. + + .. code-block:: python + @tvm.instrument.pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name + + # Uncomment to customize + # def enter_pass_ctx(self): + # pass + + # Uncomment to customize + # def exit_pass_ctx(self): + # pass + + # If pass name contains keyword, skip it by return False. (return True: not skip) + def should_run(self, mod, pass_info) + if self.skip_pass_name in pass_info.name: + return False + return True + + # Uncomment to customize + # def run_before_pass(self, mod, pass_info): + # pass + + # Uncomment to customize + # def run_after_pass(self, mod, pass_info): + # pass + + skip_annotate = SkipPass("AnnotateSpans") + with tvm.transform.PassContext(instruments=[skip_annotate]): + tvm.relay.build(mod, "llvm") + """ + + def create_pass_instrument(pi_cls): + if not inspect.isclass(pi_cls): + raise TypeError("pi_cls must be a class") + + return _wrap_class_pass_instrument(pi_cls) + + if pi_cls: + return create_pass_instrument(pi_cls) + return create_pass_instrument + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassTimingInstrument(tvm.runtime.Object): + """A wrapper to create a passes time instrument that implemented in C++""" + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) + + @staticmethod + def render(): + """Retrieve rendered time profile result + Returns + ------- + string : string + The rendered string result of time profiles + """ + return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index b4cbd5563cda..1a2854615f59 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -85,17 +85,71 @@ def reset_attr(self, attr_name): """ _ffi_api.OpResetAttr(self, attr_name) + def add_type_rel(self, rel_name, type_rel_func=None): + """Attach the type function corresponding to the return type. -def register_op(op_name): - """Register an operator by name + Parameters + ---------- + rel_name : str + The type relation name to register. + + type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type] + The backing relation function which can solve an arbitrary relation on variables. + Differences with type_rel_func in C++: + 1, when type_rel_func is not None: + 1) OpAddTypeRel on C++ side will adjust type_rel_func with TypeReporter to + calling convention of relay type system. + 2) type_rel_func returns output argument's type, return None means can't + infer output's type. + 3) only support single output operators for now, the last argument is output tensor. + 2, when type_rel_func is None, will call predefined type_rel_funcs in relay + accorrding to `tvm.relay.type_relation.` + rel_name. + """ + _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func) - Parameters - ---------- - op_name : str - The name of new operator - """ + def add_argument(self, name, type, description): # pylint: disable=redefined-builtin + """Add arguments information to the function. - _ffi_api.RegisterOp(op_name) + Parameters + ---------- + name : str + The argument name. + type : str + The argument type. + description : str + The argument description. + """ + _ffi_api.OpAddArgument(self, name, type, description) + + def set_support_level(self, level): + """Set the support level of op. + + Parameters + ---------- + level : int + The support level. + """ + _ffi_api.OpSetSupportLevel(self, level) + + def set_num_inputs(self, n): + """Set the support level of op. + + Parameters + ---------- + n : int + The input number. + """ + _ffi_api.OpSetNumInputs(self, n) + + def set_attrs_type_key(self, key): + """Set the attribute type key of op. + + Parameters + ---------- + key : str + The type key. + """ + _ffi_api.OpSetAttrsTypeKey(self, key) def register_op_attr(op_name, attr_key, value=None, level=10): diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 36e06eeb8b23..9296244f6cfe 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -65,12 +65,20 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + instruments : Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. + config : Optional[Dict[str, Object]] Additional configurations for specific passes. """ def __init__( - self, opt_level=2, required_pass=None, disabled_pass=None, trace=None, config=None + self, + opt_level=2, + required_pass=None, + disabled_pass=None, + instruments=None, + config=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -80,9 +88,13 @@ def __init__( if not isinstance(disabled, (list, tuple)): raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + instruments = list(instruments) if instruments else [] + if not isinstance(instruments, (list, tuple)): + raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, trace, config + _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config ) def __enter__(self): @@ -92,11 +104,33 @@ def __enter__(self): def __exit__(self, ptype, value, trace): _ffi_transform_api.ExitPassContext(self) + def override_instruments(self, instruments): + """Override instruments within this PassContext. + + If there are existing instruments, their exit_pass_ctx callbacks are called. + Then switching to new instruments and calling new enter_pass_ctx callbacks. + + instruments : Sequence[PassInstrument] + The list of pass instrument implementations. + """ + _ffi_transform_api.OverrideInstruments(self, instruments) + @staticmethod def current(): """Return the current pass context.""" return _ffi_transform_api.GetCurrentPassContext() + @staticmethod + def list_configs(): + """List all registered `PassContext` configuration names and metadata. + + Returns + ------- + configs : Dict[str, Dict[str, str]] + + """ + return _ffi_transform_api.ListConfigs() + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): @@ -189,6 +223,7 @@ def __init__(self, *args, **kwargs): # initialize handle in cass pass_cls creation failed.fg self.handle = None inst = pass_cls(*args, **kwargs) + # it is important not to capture self to # avoid a cyclic dependency def _pass_func(mod, ctx): @@ -330,26 +365,3 @@ def PrintIR(header="", show_meta_data=False): The pass """ return _ffi_transform_api.PrintIR(header, show_meta_data) - - -def render_pass_profiles(): - """Returns a string render of the pass profiling data. The format of each output line is - `{name}: {time} [{time excluding sub-passes}] ({% of total}; {% of parent})`. - The indentation of each line corresponds to nesting of passes. - """ - return _ffi_transform_api.render_pass_profiles() - - -def clear_pass_profiles(): - """Clears all stored pass profiling data.""" - _ffi_transform_api.clear_pass_profiles() - - -def enable_pass_profiling(): - """Enables pass profiling.""" - _ffi_transform_api.enable_pass_profiling() - - -def disable_pass_profiling(): - """Disables pass profiling.""" - _ffi_transform_api.disable_pass_profiling() diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index d95f14f0349e..910b0ce1721f 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -53,10 +53,6 @@ def path(self): return self.tempdir.temp_dir -# Required C runtime libraries, in link order. -CRT_RUNTIME_LIB_NAMES = ["utvm_rpc_server", "utvm_rpc_common", "common"] - - STANDALONE_CRT_DIR = None @@ -110,9 +106,17 @@ def get_standalone_crt_lib(name: str) -> str: return os.path.join(get_standalone_crt_dir(), "src", "runtime", "crt", name) -def get_runtime_libs() -> str: - """Return abspath to all CRT directories which contain source (i.e. not header) files.""" - return [get_standalone_crt_lib(n) for n in CRT_RUNTIME_LIB_NAMES] +def get_runtime_libs(executor: str) -> str: + """Return abspath to all CRT directories in link order which contain + source (i.e. not header) files. + """ + if executor == "host-driven": + crt_runtime_lib_names = ["utvm_rpc_server", "utvm_rpc_common", "common"] + elif executor == "aot": + crt_runtime_lib_names = ["aot_executor", "common"] + else: + raise ValueError(f"Incorrect executor: {executor}") + return [get_standalone_crt_lib(n) for n in crt_runtime_lib_names] RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE) @@ -188,6 +192,7 @@ def build_static_runtime( compiler, module, compiler_options, + executor=None, extra_libs=None, ): """Build the on-device runtime, statically linking the given modules. @@ -206,6 +211,10 @@ def build_static_runtime( used. This dict contains the `options` parameter passed to Compiler.library() and Compiler.binary() at various stages in the compilation process. + executor : Optional[str] + Executor used for runtime. Based on this we determine the libraries that need to be + linked with runtime. + extra_libs : Optional[List[MicroLibrary|str]] If specified, extra libraries to be compiled into the binary. If a MicroLibrary, it is included into the binary directly. If a string, the path to a directory; all direct children @@ -221,8 +230,11 @@ def build_static_runtime( os.makedirs(mod_build_dir) mod_src_dir = workspace.relpath(os.path.join("src", "module")) + if not executor: + executor = "host-driven" + libs = [] - for mod_or_src_dir in (extra_libs or []) + get_runtime_libs(): + for mod_or_src_dir in (extra_libs or []) + get_runtime_libs(executor): if isinstance(mod_or_src_dir, MicroLibrary): libs.append(mod_or_src_dir) continue diff --git a/python/tvm/micro/contrib/zephyr.py b/python/tvm/micro/contrib/zephyr.py index b7d7496b7440..fedd470a8350 100644 --- a/python/tvm/micro/contrib/zephyr.py +++ b/python/tvm/micro/contrib/zephyr.py @@ -172,6 +172,17 @@ def library(self, output, sources, options=None): project_dir_conf = os.path.join(self._project_dir, "prj.conf") if os.path.exists(project_dir_conf): shutil.copy(project_dir_conf, lib_prj_conf) + + # Copy board-specific Zephyr config file from the project_dir to + # the build lib dir so board-specific configs can be found and used by + # Zephyr's build system in conjunction with the generic prj.conf configs. + board_conf = os.path.join("boards", self._board + ".conf") + project_dir_board_conf = os.path.join(self._project_dir, board_conf) + if os.path.exists(project_dir_board_conf): + os.mkdir(os.path.join(output, "boards")) + lib_dir_board_conf = os.path.join(output, board_conf) + shutil.copy(project_dir_board_conf, lib_dir_board_conf) + else: with open(lib_prj_conf, "w") as prj_conf_f: prj_conf_f.write("CONFIG_CPLUSPLUS=y\n") diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index be991e22a0f8..1cc3adf9ae07 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -216,6 +216,11 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil The return value of tvm.relay.build, which will be exported into Model Library Format. file_name : str Path to the .tar archive to generate. + + Returns + ------- + file_name : str + The path to the generated .tar archive. """ tempdir = utils.tempdir() is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) @@ -260,3 +265,5 @@ def reset(tarinfo): return tarinfo tar_f.add(tempdir.temp_dir, arcname=".", filter=reset) + + return file_name diff --git a/python/tvm/relay/analysis/feature.py b/python/tvm/relay/analysis/feature.py index 99e2cdc785e6..0e264a0eef7d 100644 --- a/python/tvm/relay/analysis/feature.py +++ b/python/tvm/relay/analysis/feature.py @@ -20,7 +20,7 @@ class Feature(IntEnum): - """ The features a program might contain. """ + """The features a program might contain.""" fVar = 0 fGlobalVar = 1 diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 9460e23a5357..7378ed6beb8a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,45 +20,6 @@ from tvm.target import Target -@tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): - """Backend function for lowering. - - Parameters - ---------- - sch : tvm.te.Schedule - The schedule. - - inputs : List[tvm.te.Tensor] - The inputs to the function. - - func_name : str - The name of the function. - - source-func : tvm.relay.Function - The source function to be lowered. - - Returns - ------- - mod : tvm.IRModule - The result of lowering. - """ - # pylint: disable=broad-except, import-outside-toplevel - import traceback - - try: - f = tvm.driver.lower(sch, inputs, name=func_name) - # logging.debug("lower function %s", func_name) - # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) - except Exception: - msg = traceback.format_exc() - msg += "Error during compile function\n" - msg += "-----------------------------\n" - msg += source_func.astext() - raise RuntimeError(msg) - return f - - @tvm._ffi.register_func("relay.backend.build") def build(mod, target, target_host=None): """Backend build function. diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 4ed76f4b6366..701ca06a87e0 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -31,7 +31,7 @@ class ExecutorFactoryModule: @abstractmethod def get_executor_config(self): - """ Return the internal configuration the executor uses to execute the network """ + """Return the internal configuration the executor uses to execute the network""" raise NotImplementedError @abstractmethod @@ -41,7 +41,7 @@ def get_params(self): @abstractmethod def get_lib(self): - """ Return the generated library""" + """Return the generated library""" raise NotImplementedError def __getitem__(self, item): diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 0b6d1372d050..363ff893df8b 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -198,20 +198,26 @@ def _update_target(self, target): target = target if target else tvm.target.Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") - tgts = {} - if isinstance(target, (str, tvm.target.Target)): - dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type) - tgts[dev_type] = tvm.target.Target(target) - elif isinstance(target, dict): - for dev, tgt in target.items(): - dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type) - tgts[dev_type] = tvm.target.Target(tgt) - else: + + if isinstance(target, str): + target = {target: target} + elif isinstance(target, tvm.target.Target): + target = {target.kind.name: target} + elif not isinstance(target, dict): raise TypeError( "target is expected to be str, tvm.target.Target, " + "or dict of str to str/tvm.target.Target, but received " + "{}".format(type(target)) ) + + tgts = {} + for dev, tgt in target.items(): + dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type) + if isinstance(tgt, str): + tgt = tvm.target.Target(tgt) + + tgts[dev_type] = tgt + return tgts def _update_target_host(self, target, target_host): diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index b368f4e5175e..320a599d5d91 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -547,11 +547,11 @@ class CallPattern(DFPattern): Parameters ---------- - op: realy.dataflow_pattern.DFPattern + op: relay.dataflow_pattern.DFPattern The operation to be called. - args: List[realy.dataflow_pattern.DFPattern] - The arguments to the call. + args: List[relay.dataflow_pattern.DFPattern] + The arguments to the call or None to match any arguments. """ @@ -569,10 +569,10 @@ class FunctionPattern(DFPattern): Parameters ---------- - params: List[realy.dataflow_pattern.DFPattern] - The parameters to the Function. + params: List[relay.dataflow_pattern.DFPattern] + The parameters to the Function or None to match any parameters. - body: realy.dataflow_pattern.DFPattern + body: relay.dataflow_pattern.DFPattern The body fo the Function """ @@ -886,7 +886,7 @@ def partition( Parameters ---------- - partion: tvm.relay.dataflow_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The pattern to match expr : tvm.relay.Expr The expression to split into functions diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index caf4f1a14741..6a3ca0849f2b 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -33,7 +33,7 @@ class OperatorConverter(object): - """ Operator Converted for converting Caffe ops to Relay ops """ + """Operator Converted for converting Caffe ops to Relay ops""" def __init__(self, init_layer_dict, predict_layer, exp_tab): self.init_layer_dict = init_layer_dict @@ -66,7 +66,7 @@ def __init__(self, init_layer_dict, predict_layer, exp_tab): } def convert_flatten(self, op): - """ Convert Flatten layer """ + """Convert Flatten layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) @@ -77,7 +77,7 @@ def convert_flatten(self, op): return out def convert_eltwise(self, op): - """ Convert Eltwise layer """ + """Convert Eltwise layer""" inputs = op.bottom assert len(inputs) == 2, "input tensors length should be 2" @@ -115,7 +115,7 @@ def convert_eltwise(self, op): return out def _parse_conv_params(self, op): - """ Parse the parameters of Convolution and Deconvolution layer """ + """Parse the parameters of Convolution and Deconvolution layer""" nonzone = lambda val, pos, dflt: val[pos] if pos < len(val) else dflt conv_params = op.convolution_param @@ -160,7 +160,7 @@ def _parse_conv_params(self, op): return params def convert_batch_norm(self, op): - """ Convert BatchNorm layer """ + """Convert BatchNorm layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) n, c, h, w = _infer_shape(in_expr) @@ -215,7 +215,7 @@ def convert_batch_norm(self, op): return out[0] def convert_scale(self, op): - """ Convert Scale layer """ + """Convert Scale layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) weight_bias_blobs = self.init_layer_dict[op.name].blobs @@ -243,7 +243,7 @@ def convert_scale(self, op): return out def convert_concat(self, op): - """ Convert Concat layer """ + """Convert Concat layer""" inputs = op.bottom in_expr = (self.exp_tab.get_expr(inputs[i]) for i in range(len(inputs))) @@ -254,7 +254,7 @@ def convert_concat(self, op): return out def convert_reshape(self, op): - """ Convert Reshape layer """ + """Convert Reshape layer""" inputs = op.bottom input_name = inputs[0] @@ -294,7 +294,7 @@ def convert_reshape(self, op): return out def convert_softmax(self, op): - """ Convert Softmax layer """ + """Convert Softmax layer""" inputs = op.bottom assert len(inputs) == 1, "input tensors length should be 1" @@ -309,7 +309,7 @@ def convert_softmax(self, op): return out def convert_conv(self, op): - """ Convert Convolution layer """ + """Convert Convolution layer""" params = self._parse_conv_params(op) weight_bias_blobs = self.init_layer_dict[op.name].blobs conv_params = op.convolution_param @@ -339,7 +339,7 @@ def convert_conv(self, op): return out def convert_pooling(self, op): - """ Convert Pooling layer """ + """Convert Pooling layer""" inputs = op.bottom input_name = inputs[0] @@ -400,7 +400,7 @@ def convert_pooling(self, op): return out def convert_lrn(self, op): - """ Convert LRN layer """ + """Convert LRN layer""" inputs = op.bottom input_name = inputs[0] @@ -416,7 +416,7 @@ def convert_lrn(self, op): return out def convert_innerproduct(self, op): - """ Convert InnerProduct layer """ + """Convert InnerProduct layer""" inputs = op.bottom weight_bias_blobs = self.init_layer_dict[op.name].blobs dense_params = op.inner_product_param @@ -457,7 +457,7 @@ def convert_innerproduct(self, op): return out def convert_dropout(self, op): - """ Convert Dropout layer """ + """Convert Dropout layer""" inputs = op.bottom input_name = inputs[0] @@ -471,7 +471,7 @@ def convert_dropout(self, op): return out def convert_relu(self, op): - """ Convert ReLU layer """ + """Convert ReLU layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) negative_slope = op.relu_param.negative_slope @@ -483,7 +483,7 @@ def convert_relu(self, op): return out def convert_prelu(self, op): - """ Convert PReLU layer """ + """Convert PReLU layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) @@ -495,7 +495,7 @@ def convert_prelu(self, op): return out def convert_deconv(self, op): - """ Convert Deconvolution layer """ + """Convert Deconvolution layer""" params = self._parse_conv_params(op) weight_bias_blobs = self.init_layer_dict[op.name].blobs conv_params = op.convolution_param @@ -511,23 +511,76 @@ def convert_deconv(self, op): if weight: kh, kw = params["kernel_size"] weight_shape = [-1, conv_params.num_output, kh, kw] - weight_value = np.asarray(weight.data, np.float32) + if not weight.data: + if conv_params.weight_filler: + _filler = conv_params.weight_filler.value + weight_value = np.full(weight.shape.dim, _filler, np.float32) + else: + raise tvm.error.OpAttributeInvalid("At least weight_filler must be given") + else: + weight_value = np.asarray(weight.data, np.float32) weight_value = np.reshape(weight_value, weight_shape) else: - raise Exception("No weight value of layer {} in caffemodel".format(op.name)) + raise tvm.error.OpAttributeRequired( + "No weight value of layer {} in caffemodel".format(op.name) + ) weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") in_expr = self.exp_tab.get_expr(inputs[0]) - out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) - if bias: + groups = params["groups"] + channels = params["channels"] + + if bias: bias_value = np.asarray(bias.data, np.float32) bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") - out = _op.nn.bias_add(out, bias_expr) + + if groups > channels: + raise tvm.error.OpAttributeInvalid( + "Groups cannot be larger than the number of input channels" + ) + + if groups == channels: + inputs_expr = _op.split(in_expr, groups, axis=1) + weights_expr = _op.split(weight_expr, groups, axis=1) + # Preventing to create Concat layer with too many tensors(> 16) + q = groups >> 4 + r = groups % 16 + + params["groups"] = 1 + params["channels"] = 1 + out = [] + for lc in range(q): + _outputs = [] + _inputs = [inputs_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + _weights = [weights_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + if r != 0: + _outputs = [] + _inputs = [inputs_expr[i] for i in range(groups - r, groups)] + _weights = [weights_expr[i] for i in range(groups - r, groups)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + out = _op.concatenate(out, axis=1) + elif groups == 1: + out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) + if bias: + out = _op.nn.bias_add(out, bias_expr) + else: + raise tvm.error.OpAttributeInvalid("Unable to handle.") return out def convert_slice(self, op): - """ Convert Slice layer """ + """Convert Slice layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) @@ -545,21 +598,21 @@ def convert_slice(self, op): return out def convert_sigmoid(self, op): - """ Convert Sigmoid layer """ + """Convert Sigmoid layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) out = _op.sigmoid(in_expr) return out def convert_tanh(self, op): - """ Convert TanH layer """ + """Convert TanH layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) out = _op.tanh(in_expr) return out def convert_crop(self, op): - """ Convert Crop layer """ + """Convert Crop layer""" inputs = op.bottom assert len(inputs) == 2, "Need two inputs of Crop layer" in_expr_a = self.exp_tab.get_expr(inputs[0]) @@ -615,7 +668,7 @@ def check_unsupported_ops(self): raise tvm.error.OpNotImplemented(msg.format(ops)) def fuse_op(self, layers): - """ Fusing the BatchNorm and Scale layer """ + """Fusing the BatchNorm and Scale layer""" bn, scale = layers["bn"], layers["scale"] # bn params @@ -641,7 +694,7 @@ def fuse_op(self, layers): return bn def op_fuse(self): - """fuse bn and scale """ + """fuse bn and scale""" new_layers = [] temp_layers = {} changed_layers = {} diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 55556cf583fa..d0e8c79c6392 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -493,7 +493,7 @@ def infer_type(node, mod=None): def fold_constant(node, mod=None): if mod is None: - mod = IRModule.from_expr(node) + mod = IRModule() return _transform.FoldConstantExpr(node, mod) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3f876f401b3c..c8855b2ea2be 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1341,7 +1341,30 @@ def _impl_v10(cls, inputs, attr, params): axes = inputs[3] steps = inputs[4] - data_rank = len(infer_shape(inputs[0])) + ishape = infer_shape(inputs[0]) + data_rank = len(ishape) + + def has_static_axes(): + return ( + isinstance(axes, _expr.Constant) + and isinstance(starts, _expr.Constant) + and isinstance(ends, _expr.Constant) + and (steps is None or isinstance(steps, _expr.Constant)) + ) + + if axes is not None and has_static_axes(): + axes_np = axes.data.asnumpy().astype("int64") + begin_np = starts.data.asnumpy().astype("int64") + end_np = ends.data.asnumpy().astype("int64") + if steps is None: + strides_np = np.ones_like(begin_np).astype("int64") + else: + strides_np = steps.data.asnumpy().astype("int64") + + if all([isinstance(ishape[i], int) for i in axes_np]): + return _op.strided_slice( + inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np) + ) # Update the starts and ends according to axes if required. if axes is not None: @@ -1416,8 +1439,10 @@ class GatherND(OnnxOpConverter): @classmethod def _impl_common(cls, data, indices, batch_dims=0): indices_dims = len(infer_shape(indices)) + indices_shape = infer_shape(indices) indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) - return _op.gather_nd(data, indices, batch_dims) + index_rank = indices_shape[-1] + return _op.gather_nd(data, indices, batch_dims, index_rank) @classmethod def _impl_v1(cls, inputs, attr, params): @@ -1450,6 +1475,27 @@ def _impl_v11(cls, inputs, attr, params): ) +class EyeLike(OnnxOpConverter): + """Operator converter for EyeLike.""" + + @classmethod + def _impl_v9(cls, inputs, attr, params): + in_checked_type = infer_type(inputs[0]).checked_type + in_dtype = in_checked_type.dtype + in_shape = list(get_const_tuple(in_checked_type.shape)) + dtype = attr.get("dtype", None) + if dtype is None: + dtype = in_dtype + else: + dtype = get_type(dtype) + zeros = _op.zeros(in_shape, dtype) + dim = in_shape[0] + indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32") + ones = _op.full(_op.const(1), (dim,), dtype=dtype) + k = _op.const(attr.get("k", 0), dtype="int32") + return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update") + + class Greater(OnnxOpConverter): """Operator logical greater.""" @@ -2419,6 +2465,12 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): + if len(inputs) == 3 and isinstance(inputs[2], _expr.Constant): + attr["max"] = inputs[2].data.asnumpy().item() + inputs = inputs[0:2] + if len(inputs) >= 2 and isinstance(inputs[1], _expr.Constant): + attr["min"] = inputs[1].data.asnumpy().item() + inputs = inputs[0:1] if "min" in attr and "max" in attr: return Clip.convert_attributes(inputs, attr, params) @@ -2953,6 +3005,39 @@ def _impl_v11(cls, inputs, attr, params): return out +class Unique(OnnxOpConverter): + """Operator converter for unique""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + if len(inputs) != 1: + raise ValueError("Unique expects 1 input") + + data = inputs[0] + axis = attr.get("axis", None) + if axis is None: # If axis is None, flatten the input before calling unique + data = _op.reshape(data, _op.const([-1])) + else: + data_shape = infer_shape(data) + if len(data_shape) != 1: + raise ValueError("TVM only supports 1D Unique operator.") + is_sorted = attr.get("sorted", 1) # sorted is 0 or 1, 1 by default + + # ONNX documentation lists return_counts as optional but there is no input to specify + # whether it is returned. Therefore we'll just always return it. + unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True) + num_unique = unique[3] + + trim_unique_lambda = lambda input: _op.strided_slice(input, _op.const([0]), num_unique) + + unique_vals = trim_unique_lambda(unique[0]) + indices = trim_unique_lambda(unique[1]) + inverse_indices = unique[2] + counts = trim_unique_lambda(unique[4]) + # ONNX unique returns unique, indices, inverse_indices, (optional) counts + return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -3100,6 +3185,7 @@ def _get_convert_map(opset): "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), + "EyeLike": EyeLike.get_converter(opset), "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), @@ -3116,6 +3202,7 @@ def _get_convert_map(opset): "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), "CumSum": CumSum.get_converter(opset), + "Unique": Unique.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), @@ -3304,6 +3391,12 @@ def from_onnx(self, graph, opset, get_output_expr=False): outputs_num = 1 else: outputs_num = len(op) + + if outputs_num == 1: + op = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + if outputs_num > 1: # ONNX supports optional outputs for some nodes. # This block searches for missing outputs in the ONNX graph @@ -3325,8 +3418,8 @@ def from_onnx(self, graph, opset, get_output_expr=False): # Create the new op with valid outputs if len(outputs) == 1: op = outputs[0] - else: - op = _expr.TupleWrapper(outputs, len(outputs)) + elif len(outputs) != outputs_num: + op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs)) # Drop invalid outputs for the onnx node outputs_num = len(outputs) node_output = [output for output in node_output if output != ""] @@ -3335,10 +3428,10 @@ def from_onnx(self, graph, opset, get_output_expr=False): ), "Number of output mismatch {} vs {} in {}.".format( len(node_output), outputs_num, op_name ) + if outputs_num == 1: - self._nodes[node_output[0]] = fold_constant(op) + self._nodes[node_output[0]] = op else: - op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) for k, i in zip(list(node_output), range(len(node_output))): self._nodes[k] = op[i] diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b5cfcf5e3bac..acc33d73e826 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -754,9 +754,13 @@ def relu(self, inputs, input_types): return _op.nn.relu(data) def prelu(self, inputs, input_types): + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU data = inputs[0] - alpha = inputs[1] - return _op.nn.prelu(data, alpha) + dim = self.get_dims(data) + ndims = len(dim) + axis = 0 if ndims == 1 else 1 + alpha = _op.broadcast_to(inputs[1], (dim[axis])) + return _op.nn.prelu(data, alpha, axis) def leaky_relu(self, inputs, input_types): data = inputs[0] @@ -2294,16 +2298,18 @@ def unique(self, inputs, input_types): logging.warning("TVM always assumes sorted=True for torch.unique") is_sorted = True if return_counts: - [unique, indices, num_uniq, counts] = _op.unique( + [unique, indices, inverse_indices, num_uniq, counts] = _op.unique( data, is_sorted=is_sorted, return_counts=True ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") - return (unique_sliced, indices, counts_sliced) + return (unique_sliced, inverse_indices, counts_sliced) else: - [unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False) + [unique, indices, inverse_indices, num_uniq] = _op.unique( + data, is_sorted=is_sorted, return_counts=False + ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - return (unique_sliced, indices) + return (unique_sliced, inverse_indices) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4af73702ad9c..0bdec953a540 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -19,2707 +19,30 @@ """TF: Tensorflow frontend.""" import warnings from collections import defaultdict -from collections import deque # Numpy support import numpy as np import tvm from tvm.ir import IRModule -from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape +from tvm.relay.prelude import Prelude from tvm.relay.transform import InferType -from tvm.topi.utils import get_const_tuple from .. import analysis from .. import expr as _expr from .. import function as _function -from .. import op as _op from ..ty import Any from ..expr_functor import ExprMutator, ExprVisitor -from .common import AttrCvt, get_relay_op +from .common import get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape -from .common import infer_channels as _infer_channels from .common import infer_value as _infer_value -__all__ = ["from_tensorflow"] - - -def check_symbolic_shape(shape): - return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape]) - - -def list_shape_of(tensor, ndim): - shape_tensor = _op.shape_of(tensor) - return [ - _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim) - ] - - -def _get_pad_pair(input1d, kernel1d, stride1d): - if isinstance(input1d, tvm.tir.Any) and stride1d != 1: - raise tvm.error.OpAttributeUnImplemented( - "SAME padding is not supported in combination with dynamic height or width when stride" - " is not 1." - ) - if stride1d == 1 or input1d % stride1d == 0: - pad = max(kernel1d - stride1d, 0) - else: - pad = max(kernel1d - (input1d % stride1d), 0) - - pad_before = pad // 2 - pad_after = pad - pad_before - - return [pad_before, pad_after] - - -def _math_name_picker(surfix): - def _impl(attr): - return "broadcast_" + surfix - - return _impl - - -def _dimension_picker(prefix, surfix=""): - def _impl(attr): - kernel = attr["kernel_shape"] - if len(kernel) == 2: - return prefix + "2d" + surfix - if len(kernel) == 3: - return prefix + "3d" + surfix - raise tvm.error.OpAttributeInvalid( - "Only 2D or 3D kernels are supported for operator {}".format(prefix + "2d or 3d") - ) - - return _impl - - -def _dimension_constraint(): - def _dim_check(attrs): - if len(attrs["kernel_shape"]) in (2, 3): - return True - return False - - return _dim_check, "Only 2d or 3d kernel supported." - - -def _get_param(params, input_node): - if isinstance(input_node, _expr.Constant): - return np.atleast_1d(input_node.data.numpy()) - return params[input_node.name_hint].numpy() - - -def _get_num_param(params, input_node): - return _get_param(params, input_node).item() - - -def _get_list_param(params, input_node, mod): - try: - return _get_param(params, input_node).tolist() - except (IndexError, KeyError, AttributeError): - return _infer_value(input_node, params, mod).numpy().tolist() - - -def _get_tuple_param(params, input_node): - return tuple(_get_param(params, input_node)) - - -def _need_prelude_for_shape_inference(op): - return "TensorArray" in op - - -def _get_more_static_shape(shape0, shape1): - """Compare two shapes with the same rank, - and return the one with fewer symbolic dimension. - """ - assert len(shape0) == len(shape1) - num_sym_dim0 = 0 - num_sym_dim1 = 0 - for dim0, dim1 in zip(list(shape0), list(shape1)): - if not isinstance(dim0, int): - num_sym_dim0 += 1 - if not isinstance(dim1, int): - num_sym_dim1 += 1 - - if num_sym_dim0 < num_sym_dim1: - return shape0 - return shape1 - - -def _rsqrt(): - def _impl(inputs, attr, params, mod): - inputs.append(tvm.relay.const(-0.5, attr["T"].name)) - return AttrCvt(op_name="power")(inputs, attr) - - return _impl - - -def _argx(func, func_name): - """A common wrapper for argmin and argmax operations""" - - def _impl(inputs, attr, params, mod): - try: - # In Tensorflow, `axis` argument is a Tensor, not attribute. We - # support the case where it inputs from a scalar constant. - axis_input_value = [_get_num_param(params, inputs[1])] - except (IndexError, KeyError): - raise TypeError( - "Unsupported argument for `{}` : `axis` should be a constant".format(func_name) - ) - out = func(inputs[0], axis=axis_input_value, keepdims=False) - dtype = attr["output_type"].name - if dtype != "int32": - out = _op.cast(out, dtype=dtype) - return out - - return _impl - - -def _elemwise(name): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) - return get_relay_op(name)(*inputs) - - return _impl - - -def _pool3d(name): - def _impl(inputs, attr, params, mod): - attr["data_format"] = attr["data_format"].decode("utf-8") - flip_layout = False - - input_shape = _infer_shape(inputs[0], mod) - - if attr["data_format"] == "NDHWC": - attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2], attr["ksize"][3]) - attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3]) - elif attr["data_format"] == "NCDHW": - attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3], attr["ksize"][4]) - attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4]) - else: - msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) - if attr["data_format"] == "NDHWC": - input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)] - inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) - attr["data_format"] = "NCDHW" - flip_layout = True - - attr["padding"] = attr["padding"].decode("utf-8") - - if attr["padding"] == "VALID": - attr["padding"] = [0, 0, 0, 0, 0, 0] - elif attr["padding"] == "SAME": - stride_d, stride_h, stride_w = attr["strides"] - kernel_d, kernel_h, kernel_w = attr["kernel_shape"] - if attr["data_format"] == "NDHWC": - in_d = input_shape[1] - in_h = input_shape[2] - in_w = input_shape[3] - else: - in_d = input_shape[2] - in_h = input_shape[3] - in_w = input_shape[4] - pad_d = _get_pad_pair(in_d, kernel_d, stride_d) - pad_v = _get_pad_pair(in_h, kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, kernel_w, stride_w) - - attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] - else: - msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) - - if name == "avg_pool": - attr["count_include_pad"] = False - attr["ceil_mode"] = False - out = AttrCvt( - op_name=name, - transforms={"kernel_shape": "pool_size", "data_format": "layout"}, - ignores=["ksize"], - )(inputs, attr) - if flip_layout: - out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) - return out - - return _impl - - -def _pooling(name): - def _impl(inputs, attr, params, mod): - - attr["data_format"] = attr["data_format"].decode("utf-8") - flip_layout = False - - input_shape = _infer_shape(inputs[0], mod) - - if attr["data_format"] == "NHWC": - attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2]) - attr["strides"] = (attr["strides"][1], attr["strides"][2]) - elif attr["data_format"] == "NCHW": - attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3]) - attr["strides"] = (attr["strides"][2], attr["strides"][3]) - else: - msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) - - if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": - tmp_shape = _infer_shape(inputs[0], mod) - input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] - inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) - attr["data_format"] = "NCHW" - flip_layout = True - - # Fix padding - attr["padding"] = attr["padding"].decode("utf-8") - - if attr["padding"] == "VALID": - attr["padding"] = [0, 0] - elif attr["padding"] == "SAME": - stride_h, stride_w = attr["strides"] - kernel_h, kernel_w = attr["kernel_shape"] - if attr["data_format"] == "NHWC": - in_h = input_shape[1] - in_w = input_shape[2] - else: - in_h = input_shape[2] - in_w = input_shape[3] - - pad_v = _get_pad_pair(in_h, kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, kernel_w, stride_w) - - attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] - elif attr["padding"] == "EXPLICIT": - paddings = attr["explicit_paddings"] - assert len(paddings) == 8 - if flip_layout or attr["data_format"] == "NHWC": - attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] - else: - attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] - else: - msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) - - if name == "avg_pool": - attr["count_include_pad"] = False - - out = AttrCvt( - op_name=_dimension_picker(name), - transforms={"kernel_shape": "pool_size", "data_format": "layout"}, - ignores=["ksize", "explicit_paddings"], - extras={"ceil_mode": False}, - custom_check=_dimension_constraint(), - )(inputs, attr) - - if flip_layout: - out = _op.transpose(out, axes=(0, 2, 3, 1)) - - return out - - return _impl - - -def _conv(opname): - def _impl(inputs, attr, params, mod): - attr["data_format"] = attr["data_format"].decode("utf-8") - flip_layout = False - - if opname == "conv_transpose" and attr["data_format"] == "NHWC": - # transform to NCHW for TVM backend compatible and set 'flip_layout' - # to have output flip back to NHWC - inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) - attr["strides"][1], attr["strides"][2], attr["strides"][3] = ( - attr["strides"][3], - attr["strides"][1], - attr["strides"][2], - ) - attr["data_format"] = "NCHW" - - # Check whether output shapes attribute is set and not None - if ( - opname == "conv_transpose" - and len(attr["_output_shapes"]) > 0 - and attr["_output_shapes"][0] - ): - tmp_shape = attr["_output_shapes"][0] - tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] - attr["_output_shapes"][0] = tmp_shape - - flip_layout = True - - inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] - - # NCHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1], mod) - if attr["data_format"] == "NCHW": - tmp_shape = weights_shape - if opname in ["conv", "conv_transpose"]: - tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] - inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) - else: - tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)] - inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) - weights_shape = tmp_shape - - input_shape = _infer_shape(inputs_data, mod) - if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": - input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) - if opname in ["conv", "conv_transpose"]: - weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] - inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) - else: - weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)] - inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) - - attr["data_format"] = "NCHW" - attr["strides"] = [attr["strides"][ii] for ii in (0, 3, 1, 2)] - flip_layout = True - - if attr["data_format"] == "NHWC": - in_channels = input_shape[3] - kernel_h, kernel_w, _, depth_mult = weights_shape - attr["kernel_shape"] = (weights_shape[0], weights_shape[1]) - if opname == "conv": - attr["channels"] = weights_shape[3] - elif opname == "conv_transpose": - attr["channels"] = weights_shape[2] - else: - attr["channels"] = input_shape[3] * depth_mult - - if "dilations" in attr: - attr["dilations"] = (attr["dilations"][1], attr["dilations"][2]) - attr["strides"] = (attr["strides"][1], attr["strides"][2]) - elif attr["data_format"] == "NCHW": - in_channels = input_shape[1] - _, depth_mult, kernel_h, kernel_w = weights_shape - attr["kernel_shape"] = (weights_shape[2], weights_shape[3]) - if opname == "conv": - attr["channels"] = weights_shape[0] - elif opname == "conv_transpose": - attr["channels"] = weights_shape[1] - else: - attr["channels"] = input_shape[1] * depth_mult - if attr["channels"] < 0: - attr["channels"] *= -1 - - if "dilations" in attr: - attr["dilations"] = (attr["dilations"][2], attr["dilations"][3]) - attr["strides"] = (attr["strides"][2], attr["strides"][3]) - else: - msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) - - if opname == "depthwise": - attr["groups"] = in_channels - - # Fix padding - attr["padding"] = attr["padding"].decode("utf-8") - - if attr["padding"] == "VALID": - attr["padding"] = [0, 0] - elif attr["padding"] == "SAME": - stride_h, stride_w = attr["strides"] - kernel_h, kernel_w = attr["kernel_shape"] - - pdata_shape = input_shape - # Check whether output shapes attribute is set and not None - if ( - opname == "conv_transpose" - and len(attr["_output_shapes"]) > 0 - and attr["_output_shapes"][0] - ): - pdata_shape = attr["_output_shapes"][0] - - if attr["data_format"] == "NHWC": - in_h = pdata_shape[1] - in_w = pdata_shape[2] - else: - in_h = pdata_shape[2] - in_w = pdata_shape[3] - - dilation_h = attr["dilations"][0] - dilation_w = attr["dilations"][1] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - - attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] - elif attr["padding"] == "EXPLICIT": - paddings = attr["explicit_paddings"] - assert len(paddings) == 8 - if flip_layout or attr["data_format"] == "NHWC": - attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] - else: - attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] - else: - msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) - - if "kernel_layout" not in attr: - if opname in ["conv", "conv_transpose"]: - attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW" - else: - attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW" - - # Ignore the new attributes from TF2.0, for now. - out = AttrCvt( - op_name=_dimension_picker( - "conv", surfix="_transpose" if opname == "conv_transpose" else "" - ), - ignores=["explicit_paddings"], - transforms={ - "kernel_shape": "kernel_size", - "data_format": "data_layout", - "dilations": ("dilation", (0, 0)), - "group": ("groups", 1), - }, - custom_check=_dimension_constraint(), - )([inputs_data, inputs[1]], attr) - - if flip_layout: - out = _op.transpose(out, axes=(0, 2, 3, 1)) - - return out - - return _impl - - -# Dilation2d -def _dilation2d(): - def _impl(inputs, attr, params, mod): - if "data_format" not in attr: - attr["data_format"] = "NHWC" - - input_shape = _infer_shape(inputs[0], mod) - weights_shape = _infer_shape(inputs[1], mod) - - if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": - input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) - weights_shape = [weights_shape[ii] for ii in (2, 0, 1)] - inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1)) - attr["data_format"] = "NCHW" - - if attr["data_format"] in ["NHWC", "NCHW"]: - if "rates" in attr: - attr["dilations"] = attr["rates"] - if "dilations" in attr: - attr["dilations"] = (attr["dilations"][1], attr["dilations"][2]) - attr["strides"] = (attr["strides"][1], attr["strides"][2]) - else: - msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' "not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) - - attr["padding"] = attr["padding"].decode("utf-8") - if attr["padding"] == "VALID": - attr["padding"] = [0, 0] - elif attr["padding"] == "SAME": - stride_h, stride_w = attr["strides"] - if attr["data_format"] == "NHWC": - kernel_h, kernel_w = weights_shape[0], weights_shape[1] - else: - kernel_h, kernel_w = weights_shape[1], weights_shape[2] - if attr["data_format"] == "NHWC": - in_h = input_shape[1] - in_w = input_shape[2] - else: - in_h = input_shape[2] - in_w = input_shape[3] - - dilation_h = attr["dilations"][0] - dilation_w = attr["dilations"][1] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - - if attr["data_format"] == "NHWC": - inputs[0] = _op.nn.pad( - data=inputs[0], - pad_width=((0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1]), (0, 0)), - ) - else: - inputs[0] = _op.nn.pad( - data=inputs[0], - pad_width=((0, 0), (0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1])), - ) - - attr["padding"] = [0, 0] - - else: - msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' "valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) - - attr["kernel_layout"] = "HWI" if attr["data_format"] == "NHWC" else "IHW" - out = AttrCvt( - op_name="dilation2d", - ignores=["explicit_paddings", "rates"], - transforms={ - "data_format": "data_layout", - }, - )([inputs[0], inputs[1]], attr) - if attr["_target_layout"] == "NCHW": - out = _op.transpose(out, axes=(0, 2, 3, 1)) - return out - - return _impl - - -def _conv3d(opname): - def _impl(inputs, attr, params, mod): - attr["data_format"] = attr["data_format"].decode("utf-8") - flip_layout = False - - inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] - - # NCDHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1], mod) - if attr["data_format"] == "NCDHW": - tmp_shape = weights_shape - tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] - inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) - weights_shape = tmp_shape - - input_shape = _infer_shape(inputs_data, mod) - - if attr["_target_layout"] == "NCDHW" and attr["data_format"] == "NDHWC": - input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] - inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3)) - weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)] - inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) - - attr["data_format"] = "NCDHW" - attr["strides"] = [attr["strides"][ii] for ii in (0, 4, 1, 2, 3)] - flip_layout = True - - if attr["data_format"] == "NDHWC": - kernel_d, kernel_h, kernel_w, _, _ = weights_shape - attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w) - if opname == "conv": - attr["channels"] = weights_shape[4] - elif opname == "conv_transpose": - attr["channels"] = weights_shape[3] - - if "dilations" in attr: - attr["dilations"] = ( - attr["dilations"][1], - attr["dilations"][2], - attr["dilations"][3], - ) - attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3]) - elif attr["data_format"] == "NCDHW": - _, _, kernel_d, kernel_h, kernel_w = weights_shape - attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w) - if opname == "conv": - attr["channels"] = weights_shape[0] - elif opname == "conv_transpose": - attr["channels"] = weights_shape[1] - - if "dilations" in attr: - attr["dilations"] = ( - attr["dilations"][2], - attr["dilations"][3], - attr["dilations"][4], - ) - attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4]) - else: - msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) - - # Fix padding - attr["padding"] = attr["padding"].decode("utf-8") - - if attr["padding"] == "VALID": - attr["padding"] = [0, 0, 0] - elif attr["padding"] == "SAME": - stride_d, stride_h, stride_w = attr["strides"] - kernel_d, kernel_h, kernel_w = attr["kernel_shape"] - - pdata_shape = input_shape - if opname == "conv_transpose" and len(attr["_output_shapes"]) > 0: - pdata_shape = attr["_output_shapes"][0] - - if attr["data_format"] == "NDHWC": - in_d = pdata_shape[1] - in_h = pdata_shape[2] - in_w = pdata_shape[3] - else: - in_d = pdata_shape[2] - in_h = pdata_shape[3] - in_w = pdata_shape[4] - - dilation_d = attr["dilations"][0] - dilation_h = attr["dilations"][1] - dilation_w = attr["dilations"][2] - dilated_kernel_d = (kernel_d - 1) * dilation_d + 1 - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d) - pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - - attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] - elif attr["padding"] == "EXPLICIT": - paddings = attr["explicit_paddings"] - assert len(paddings) == 10 - if flip_layout or attr["data_format"] == "NDHWC": - attr["padding"] = [ - paddings[2], - paddings[4], - paddings[6], - paddings[3], - paddings[5], - paddings[7], - ] - else: - attr["padding"] = [ - paddings[4], - paddings[6], - paddings[8], - paddings[5], - paddings[7], - paddings[9], - ] - else: - msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." - raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) - - if "kernel_layout" not in attr: - attr["kernel_layout"] = "DHWIO" if attr["data_format"] == "NDHWC" else "OIDHW" - - use_bias = len(inputs) == (3 if opname != "conv_transpose" else 4) - channel_axis = 1 if attr["data_format"] == "NCDHW" else 4 - - # Ignore the new attributes from TF2.0, for now. - out = AttrCvt( - op_name=_dimension_picker( - "conv", surfix="_transpose" if opname == "conv_transpose" else "" - ), - ignores=["explicit_paddings", "Tshape"], - transforms={ - "kernel_shape": "kernel_size", - "data_format": "data_layout", - "dilations": ("dilation", (0, 0)), - "group": ("groups", 1), - }, - custom_check=_dimension_constraint(), - )([inputs_data, inputs[1]], attr) - - if use_bias: - out = _op.nn.bias_add( - out, inputs[2] if opname != "conv_transpose" else inputs[3], axis=channel_axis - ) - - if flip_layout: - out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) - - return out - - return _impl - - -def _nms(return_scores=False): - def _impl(inputs, attr, params, mod): - # Get parameter values - try: - max_output_size = int(np.atleast_1d(inputs[2].data.numpy().astype("int64"))[0]) - except Exception: - try: - max_output_size = ( - _infer_value(inputs[2], params, mod).numpy().astype("int64").tolist()[0] - ) - except Exception: - max_output_size = inputs[2] - iou_threshold = np.atleast_1d(inputs[3].data.numpy())[0] - # score_threshold was introduced from V3 - score_threshold = np.atleast_1d(inputs[4].data.numpy())[0] if len(inputs) > 4 else 0.0 - pad_output = "pad_to_max_output_size" - - # Generate data with shape (1, num_anchors, 5) - scores = AttrCvt( - op_name="expand_dims", - ignores=["T_threshold", pad_output], - extras={"axis": -1, "num_newaxis": 1}, - )([inputs[1]], attr) - data = get_relay_op("concatenate")([scores, inputs[0]], -1) - data = get_relay_op("expand_dims")(data, 0, 1) - - # reason why using get_valid_counts is for inference performance - ct, data, indices = get_relay_op("get_valid_counts")( - data, score_threshold=score_threshold, id_index=-1, score_index=0 - ) - # TensorFlow NMS doesn't have parameter top_k - top_k = -1 - # TF doesn't have class id for nms input - score_index = 0 - nms_ret = get_relay_op("non_max_suppression")( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=score_index, - id_index=-1, - return_indices=True, - invalid_to_bottom=False, - ) - - if pad_output in attr and attr[pad_output]: - return nms_ret - # squeeze it, TF NMS is not batched - size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) - data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) - - # slice to get the dynamic result - ret = get_relay_op("strided_slice")( - data_slice, begin=_expr.const([0]), end=size, slice_mode="size" - ) - - # NonMaxSuppressionV5 returns scores. pad_output is always False for NMSv5. - if return_scores: - if "soft_nms_sigma" in attr and attr["soft_nms_sigma"] != 0.0: - raise tvm.error.OpAttributeUnImplemented( - "soft_nms_sigma for NonMaxSuppressionV5 is not supported" - ) - ret_scores = _op.take(inputs[1], ret, axis=0) - return _expr.TupleWrapper(_expr.Tuple([ret, ret_scores, size]), 3) - - return ret - - return _impl - - -def _combined_nms(): - def _impl(inputs, attr, params, mod): - # Get parameter values - boxes = inputs[0] - scores = inputs[1] - try: - max_output_size = int(np.atleast_1d(inputs[2].data.numpy().astype("int64"))[0]) - except Exception: - try: - max_output_size = ( - _infer_value(inputs[2], params, mod).numpy().astype("int64").tolist()[0] - ) - except Exception: - max_output_size = inputs[2] - max_total_size = inputs[3] - iou_threshold = np.atleast_1d(inputs[4].data.numpy())[0] - score_threshold = np.atleast_1d(inputs[5].data.numpy())[0] - if attr["pad_per_class"]: - raise tvm.error.OpAttributeUnImplemented( - "pad_per_class for CombinedNonMaxSuppression is not supported" - ) - boxes_shape = _infer_shape(inputs[0], mod) - scores_shape = _infer_shape(inputs[1], mod) - batch_size = boxes_shape[0] - num_anchors = boxes_shape[1] - q = boxes_shape[2] - num_classes = scores_shape[2] - - if q != num_classes: - # When q is 1, it means same box coords are used for all classes. - boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4)) - boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4]) - scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1]) - - # In TF, class is specified by memory layout only. - ids = _op.arange(_op.const(num_classes, dtype="float32")) - ids = _op.broadcast_to(ids, (batch_size, num_anchors, num_classes)) - ids = _op.reshape(ids, newshape=[batch_size, num_anchors * num_classes, 1]) - - data = _op.concatenate([ids, scores, boxes], -1) - ct, data, indices = _op.vision.get_valid_counts( - data, score_threshold=score_threshold, id_index=0, score_index=1 - ) - nms_ret = _op.vision.non_max_suppression( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - force_suppress=False, - top_k=-1, - coord_start=2, - score_index=1, - id_index=0, - return_indices=False, - invalid_to_bottom=True, - ) - # Dynamic slice to max_total_size - neg_one = _expr.const([-1]) - slice_end = _op.concatenate( - [neg_one, _op.expand_dims(max_total_size, axis=0), neg_one], axis=0 - ) - nms_ret = _op.strided_slice( - nms_ret, begin=[0, 0, 0], end=slice_end, strides=[1, 1, 1], slice_mode="size" - ) - - # Slice output into boxes, scores, classes - nmsed_boxes = _op.strided_slice( - nms_ret, begin=[0, 0, 2], end=[-1, -1, 4], slice_mode="size" - ) - if attr["clip_boxes"]: - nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) - nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) - nmsed_scores = _op.strided_slice( - nms_ret, begin=[0, 0, 1], end=[-1, -1, 1], slice_mode="size" - ) - nmsed_scores = _op.squeeze(nmsed_scores, axis=[2]) - nmsed_classes = _op.strided_slice( - nms_ret, begin=[0, 0, 0], end=[-1, -1, 1], slice_mode="size" - ) - nmsed_classes = _op.squeeze(nmsed_classes, axis=[2]) - # Get number of valid boxes - nms_count = _op.sum( - _op.cast(_op.greater(nmsed_scores, _expr.const(0, dtype="float32")), "int32"), axis=1 - ) - - # TVM uses -1 for invalid outputs while TF uses 0 - box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32") - shape = _op.strided_slice(_op.shape_of(nmsed_boxes), begin=[0], end=[2]) - box_range = _op.broadcast_to(box_range, shape) - valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(nms_count, axis=1)), "float32") - nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) - # Could instead use mask for scores, classes if negative values are possible. - nmsed_scores = _op.maximum(nmsed_scores, _expr.const(0, dtype="float32")) - nmsed_classes = _op.maximum(nmsed_classes, _expr.const(0, dtype="float32")) - - return _expr.TupleWrapper( - _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, nms_count]), 4 - ) - - return _impl - - -def _decode_image(): - def _impl(inputs, attr, params, mod): - # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. - warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input") - return inputs[0] - - return _impl - - -def _unravel_index(): - def _impl(inputs, attr, params, mod): - return _op.unravel_index(inputs[0], inputs[1]) - - return _impl - - -def _crop_and_resize(): - def _impl(inputs, attr, params, mod): - # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] - # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] - crop_size = _get_list_param(params, inputs[3], mod) - - method = attr["method"].decode() - method = "nearest_neighbor" if method == "nearest" else method - if method not in ["bilinear", "nearest_neighbor"]: - raise tvm.error.OpAttributeUnImplemented("Method {} is not supported".format(method)) - layout = attr["layout"] if "layout" in attr else "NHWC" - extrapolation_value = attr["extrapolation_value"] - - return get_relay_op("crop_and_resize")( - inputs[0], inputs[1], inputs[2], crop_size, layout, method, extrapolation_value - ) - - return _impl - - -def _cast(): - def _impl(inputs, attr, params, mod): - return inputs[0].astype(attr["DstT"].name) - - return _impl - - -def _expand_dims(): - def _impl(inputs, attr, params, mod): - dim_input = inputs.pop(1) - axis = _get_num_param(params, dim_input) - return AttrCvt( - op_name="expand_dims", - ignores=["Tdim", "N"], - extras={"axis": int(axis), "num_newaxis": 1}, - )(inputs, attr) - - return _impl - - -def _expm1(): - # op description: https://www.tensorflow.org/api_docs/python/tf/math/expm1 - def _impl(inputs, attr, params, mod): - exp_out = get_relay_op("exp")(inputs[0]) - return exp_out - tvm.relay.const(1.0) - - return _impl - - -def _resize(method): - def _impl(inputs, attr, params, mod): - if attr["_output_shapes"][0] is not None: - size = attr["_output_shapes"][0][1:3] - # Important that the size is defined. If an axis is not, we need to infer what - # the shape should be. - if -1 in size: - size = _infer_value(inputs[1], params, mod).numpy().reshape([-1]).tolist() - else: - size = _infer_value(inputs[1], params, mod).numpy().reshape([-1]).tolist() - - attr["size"] = size - inputs.pop(1) - # NHWC - attr["layout"] = "NHWC" - if attr.pop("align_corners") is True: - attr["coordinate_transformation_mode"] = "align_corners" - else: - attr["coordinate_transformation_mode"] = "asymmetric" - - # Ignore the new attributes from TF2.0, for now. - return AttrCvt( - op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} - )(inputs, attr) - - return _impl - - -def _check_numerics(): - def _impl(inputs, attr, params, mod): - # Making a copy node assuming no need to verify - return AttrCvt(op_name="copy", ignores=["message"])(inputs, attr) - - return _impl - - -def _assert(): - # ToDo: In general people want asserts to be gone from TensorFlow graphs - # when they are optimizing them, so converting it to a no-op is - # reasonable. However, it would be nice to have the option to keep them - # once Relay gets a Halt or Assert op. - return _no_op() - - -def _no_op(): - def _impl(inputs, attr, params, mod): - # ToDo: This should really be an op that returns nothing, which could - # be represented as an empty tuple. It turns out that TVM - # infrastructure doesn't like running functions that return None and - # also don't like running functions that return an empty tuple. So it - # doesn't work, but it should be made to work and then this could be - # improved. In the mean time, it is hard to imagine a case where it - # matters in any real way that a no-op is converted to a constant 0. - return tvm.relay.const(0) - - return _impl - - -def _matmul(): - def _impl(inputs, attr, params, mod): - channels = _infer_channels(inputs[1], not attr["transpose_b"]) - if attr["transpose_a"]: - inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) - if not attr["transpose_b"]: - inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) - return AttrCvt( - op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"] - )(inputs, attr) - - return _impl - - -def _batch_matmul(): - def _impl(inputs, attr, params, mod): - input_x = inputs[0] - input_y = inputs[1] - orig_shape_x = _infer_shape(input_x, mod) - orig_shape_y = _infer_shape(input_y, mod) - ndim = len(orig_shape_x) - - is_static = not check_symbolic_shape(orig_shape_x) - - if ndim > 3 and not is_static: - shape_of_x = list_shape_of(inputs[0], ndim) - shape_of_y = list_shape_of(inputs[1], ndim) - - # reshape n-dimensional batch matmul into 3d - if ndim > 3: - outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] - if is_static: - num_outer_elts = np.prod(outer_dims) - new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) - new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) - else: # handle dynamic shape (dyn.reshape op) - # new shape = [prod(shape[:-2]), -2, -1] - new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]] - new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]] - for i in range(ndim - 2): - new_shape_x[0] *= shape_of_x[i] - new_shape_y[0] *= shape_of_y[i] - new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0) - new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0) - - input_x = _op.reshape(input_x, newshape=new_shape_x) - input_y = _op.reshape(input_y, newshape=new_shape_y) - - adj_x = attr["adj_x"] - adj_y = attr["adj_y"] - input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x - input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y - ret = get_relay_op("batch_matmul")(input_x, input_y) - - # reshape result back to n-dimensional - if ndim > 3: - if is_static: - final_shape = list(orig_shape_x) - final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2] - final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] - else: - # calculate the resulting shape = [shape[:-2], 0, 0] - final_shape = list(shape_of_x) - final_shape[-2] = shape_of_x[-1] if adj_x else shape_of_x[-2] - final_shape[-1] = shape_of_y[-2] if adj_y else shape_of_y[-1] - final_shape = _op.concatenate(_op.Tuple(final_shape), axis=0) - - ret = _op.reshape(ret, newshape=final_shape) - return ret - - return _impl - - -def _sparse_tensor_dense_matmul(): - def _impl(inputs, attr, params, mod): - # Loading this by default causes TVM to not be loadable from other languages. - # Sparse utility from scipy - from scipy.sparse import csr_matrix - - assert len(inputs) == 4, "There should be 4 input tensors" - - indices_tensor = _infer_value(inputs[0], params, mod).numpy() - values_tensor = _infer_value(inputs[1], params, mod).numpy() - dense_shape_tensor = _infer_value(inputs[2], params, mod).numpy() - - data = inputs[3] - - rows = [x[0] for x in indices_tensor] - cols = [x[1] for x in indices_tensor] - - # Create scipy sparse Tensor(CSR) - weight_sp = csr_matrix( - (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) - ) - - # As per tensorflow implementation, we have 4 possible input combination - # and the first input(A) is always sparse and second input(B) is always dense. - # Case 1: A , B , adjoint_a=False, adjoint_b=False --> A * B - # Case 2: A , B , adjoint_a=True, adjoint_b=False --> A.T * B - # Case 3: A , B , adjoint_a=False, adjoint_b=True --> A * B.T - # Case 4: A , B , adjoint_a=True, adjoint_b=True --> A.T * B.T - # - # Topi implementation for sparse_dense(matmul) has 2 possible input - # combination where first input(A) is always dense - # and second input(B) is always sparse. - # Case 1: A , B, sparse_lhs = False --> A * B.T - # Case 2: A , B, sparse_lhs = True --> B * A.T - # - # The mapping would be as below: - # TF Case 1: A , B , adjoint_a=False, adjoint_b=False - # --> In TF: A * B --> In Topi: A * B.T.T - # --> sparse_dense(transpose(B), A, sparse_lhs=True) - # - # TF Case 2: A , B , adjoint_a=True, adjoint_b=False - # --> In TF: A.T * B --> In Topi: A.T * B.T.T - # --> sparse_dense(transpose(B), transpose(A), sparse_lhs=True) - # - # TF Case 3: A , B , adjoint_a=False, adjoint_b=True - # --> In TF: A * B.T --> In Topi: A * B - # --> sparse_dense(B, A, sparse_lhs=True) - # - # TF Case 4: A , B , adjoint_a=True, adjoint_b=True - # --> In TF: A.T * B.T --> In Topi: (B * A.T).T - # --> transpose(sparse_dense(B, transpose(A), sparse_lhs=False)) - - # By default, in tensorflow the first input ,i.e., data is sparse - sparse_lhs = True - - # TF Case 1: - if not attr.get("adjoint_a") and not attr.get("adjoint_b"): - data = _op.transpose(data) - # TF Case 2: - elif attr.get("adjoint_a") and not attr.get("adjoint_b"): - data = _op.transpose(data) - weight_sp = csr_matrix(weight_sp.transpose()) - # TF Case 3: - elif not attr.get("adjoint_a") and attr.get("adjoint_b"): - pass - # TF Case 4: - # attr.get("adjoint_a") and attr.get("adjoint_b"): - else: - sparse_lhs = False - weight_sp = csr_matrix(weight_sp.transpose()) - - weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) - weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) - weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) - - ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs) - - if not sparse_lhs: - # TF Case 4 - ret = _op.transpose(ret) - - return ret - - return _impl - - -def _sparse_fill_empty_rows(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 4, "There should be 4 input tensors" - sparse_indices = inputs[0] - sparse_values = inputs[1] - sparse_indices_num_cols = _infer_shape(sparse_indices, mod)[1] - first_column = _op.split(sparse_indices, sparse_indices_num_cols, axis=1)[0] - sorted_indices = _op.argsort(_op.squeeze(first_column)) - sorted_sparse_indices = _op.take(sparse_indices, sorted_indices, axis=0) - sorted_sparse_values = _op.take(sparse_values, sorted_indices, axis=0) - new_sparse_indices, new_sparse_values, empty_row_indicator = _op.sparse_fill_empty_rows( - sorted_sparse_indices, sorted_sparse_values, inputs[2], inputs[3] - ) - - return _expr.TupleWrapper( - _expr.Tuple([new_sparse_indices, new_sparse_values, empty_row_indicator]), - 3, - ) - - return _impl - - -def _sparse_reshape(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 3, "There should be 3 input tensors" - new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2]) - return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2) - - return _impl - - -def _math_segment_sum(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 2, "There should be 2 input tensors" - return get_relay_op("segment_sum")(inputs[0], inputs[1]) - - return _impl - - -def _sparse_segment_sum(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 3, "There should be 3 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - return _op.segment_sum(data, inputs[2]) - - return _impl - - -def _sparse_segment_sum_with_num_segments(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 4, "There should be 4 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - num_segments = int(inputs[3].data.numpy().item()) - return _op.segment_sum(data, inputs[2], num_segments) - - return _impl - - -def row_wise_divide(multi_dim_tensor, one_dim_vector): - """ - This function enables row-wise division of multi_dim_tensor and one_dim_vector. - To achieve this, it is first tiled to the appropriate shape and then elemwise_division - """ - multi_dim_tensor_offrow_shape = _op.strided_slice( - _op.shape_of(multi_dim_tensor, "int32"), [1], [-1], slice_mode="size" - ) - one_dim_vector_tiled_shape = _op.concatenate( - [_op.reverse(multi_dim_tensor_offrow_shape, 0), _expr.const([1])], axis=0 - ) - one_dim_vector_tiled = _op.transpose(_op.tile(one_dim_vector, one_dim_vector_tiled_shape)) - return _op.divide(multi_dim_tensor, one_dim_vector_tiled) - - -def count_all_indices(segment_ids, counts_dtype, num_segments=None): - """ - This snippet calculates the sqrt count of each index among all valid indices - Valid indices are from 0 to max of [segment ids, num_segments] - """ - - max_segments = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1]) - if num_segments: - max_segments = _op.maximum(max_segments, _expr.const([num_segments])) - max_ones = _op.maximum(max_segments, _op.shape_of(segment_ids)) - counts = _op.segment_sum( - _op.ones(max_ones, counts_dtype), segment_ids, num_segments=num_segments - ) - real_counts = _op.clip(counts, 1, 2147483647) # Clip max doesn't work over int32 - return real_counts - - -def _sparse_segment_sum_sqrtn(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 3, "There should be 3 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - real_counts = count_all_indices(inputs[2], attr["T"].name) - real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) - - # Calculate regular segment sum - segment_sum = _op.segment_sum(data, inputs[2]) - - return row_wise_divide(segment_sum, real_sqrt_counts) - - return _impl - - -def _sparse_segment_sum_sqrtn_with_num_segments(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 4, "There should be 4 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - num_segments = int(inputs[3].data.numpy().item()) - real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) - real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) - - # Calculate regular segment sum - segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) - - return row_wise_divide(segment_sum, real_sqrt_counts) - - return _impl - - -def _sparse_segment_mean(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 3, "There should be 3 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - real_counts = count_all_indices(inputs[2], attr["T"].name) - - # Calculate regular segment sum - segment_sum = _op.segment_sum(data, inputs[2]) - - return row_wise_divide(segment_sum, real_counts) - - return _impl - - -def _sparse_segment_mean_with_num_segments(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 4, "There should be 4 input tensors" - data = _op.take(inputs[0], inputs[1], axis=0) - num_segments = int(inputs[3].data.numpy().item()) - real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) - - # Calculate regular segment sum - segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) - - return row_wise_divide(segment_sum, real_counts) - - return _impl - - -def _sparse_tensor_dense_add(): - # Sparse utility from scipy - from scipy.sparse import csr_matrix - - def _impl(inputs, attr, params, mod): - assert ( - len(inputs) == 4 - ), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]." - - indices_tensor = _infer_value(inputs[0], params, mod).numpy() - values_tensor = _infer_value(inputs[1], params, mod).numpy() - dense_shape_tensor = _infer_value(inputs[2], params, mod).numpy() - - data = inputs[3] - - rows = [x[0] for x in indices_tensor] - cols = [x[1] for x in indices_tensor] - - # Create scipy sparse Tensor(CSR) - weight_sp = csr_matrix( - (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) - ) - - weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) - weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) - weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) - - ret = _op.nn.sparse_add(data, [weight_data, weight_indices, weight_indptrs]) - - return ret - - return _impl - - -def _identity(): - def _impl(inputs, attr, params, mod): - return inputs[0] - - return _impl - - -def _identityn(): - def _impl(inputs, attr, params, mod): - return inputs - - return _impl - - -def _concatV2(): - def _impl(inputs, attr, params, mod): - pop_node = inputs.pop(len(inputs) - 1) - axis = int(_get_num_param(params, pop_node)) - return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})( - [inputs], attr - ) - - return _impl - - -def _concat(): - def _impl(inputs, attr, params, mod): - pop_node = inputs.pop(0) - axis = int(_get_num_param(params, pop_node)) - return AttrCvt(op_name="concatenate", ignores=["N"], extras={"axis": axis})([inputs], attr) - - return _impl - - -def _pack(): - def _impl(inputs, attr, params, mod): - axis = int(attr["axis"]) - inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] - return _op.concatenate(inputs_reshaped, axis) - - return _impl - - -def _tensor_array(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr.get("dtype").name - assert not attr["dynamic_size"], "Dynamic size tensor array is " "not supported in TVM yet." - - if "shape" in attr: - shape = attr["shape"] - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) - static_tensor_array_ops.register() - tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") - tensor_array = tensor_array_constructor(inputs[0]) - else: - tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) - tensor_array = tensor_array_constructor(inputs[0]) - return tensor_array - - return _impl - - -def _tensor_array_scatter(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr.get("T").name - input_ta = inputs[0] - input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - values_shape = _infer_shape(inputs[2], prelude.mod) - input_t_shape = values_shape[1:] - indices_shape = _infer_shape(inputs[1], prelude.mod) - - if input_shape is None: - values_rank = len(values_shape) - unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) - unstack_function = prelude.get_global_var(unstack_name, dtype_str) - values = unstack_function(inputs[2]) - tensor_array_scatter_func = prelude.get_global_var("tensor_array_scatter", dtype_str) - else: - input_t_shape = _get_more_static_shape(input_t_shape, input_shape) - values_shape = (values_shape[0],) + input_t_shape - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_t_shape) - static_tensor_array_ops.register() - # Register static indices shape - if isinstance(indices_shape[0], int): - static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) - tensor_array_scatter_func = prelude.get_global_var_static( - "tensor_array_scatter", dtype_str, input_t_shape - ) - - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, values_shape) - static_tensor_array_ops.register() - unstack_function = prelude.get_global_var_static( - "tensor_array_unstack", dtype_str, values_shape - ) - values = unstack_function(inputs[2]) - ret = tensor_array_scatter_func(input_ta, inputs[1], values) - return ret - - return _impl - - -def _tensor_array_gather(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr.get("dtype").name - input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) - indices_shape = _infer_shape(inputs[1], prelude.mod) - - if input_shape is None: - gather_func = prelude.get_var("tensor_array_gather", dtype_str) - out = gather_func(inputs[2], inputs[1]) - else: - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) - static_tensor_array_ops.register() - - if not isinstance(indices_shape[0], int): - gather_function = prelude.get_global_var_static( - "tensor_array_gather", dtype_str, input_shape - ) - out_tensor_t = gather_function(inputs[2], inputs[1]) - out_shape = (indices_shape[0],) + input_shape - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) - static_tensor_array_ops.register() - - # Output shape is (indices_shape[0],) + input_shape - get_data_func = prelude.get_global_var_static( - "tensor_get_data", dtype_str, out_shape - ) - out = get_data_func(out_tensor_t) - else: - # For fixed length indices, directly generate static shape output - read_func = prelude.get_global_var_static( - "tensor_array_read", dtype_str, input_shape - ) - get_data_func = prelude.get_global_var_static( - "tensor_get_data", dtype_str, input_shape - ) - tensor_list = [] - for i in range(indices_shape[0]): - index = _op.take(inputs[1], tvm.relay.const(i)) - out_tensor = get_data_func(read_func(inputs[2], index)) - tensor_list.append(_op.expand_dims(out_tensor, axis=0)) - - if indices_shape[0] > 1: - out = _op.concatenate(tensor_list, axis=0) - else: - out = tensor_list[0] - - return out - - return _impl - - -def _tensor_array_size(): - def _impl(inputs, attr, params, prelude): - return prelude.length(inputs[0]) - - return _impl - - -def _tensor_array_write(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr.get("T").name - input_ta = inputs[3] - input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - input_t_shape = _infer_shape(inputs[2], prelude.mod) - input_rank = len(input_t_shape) - - if input_ta_shape is None: - tensor_name = "tensor{}".format(input_rank) - tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) - v = tensor_func(inputs[2]) - write_func = prelude.get_global_var("tensor_array_write", dtype_str) - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( - input_ta_rank, input_rank - ) - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) - static_tensor_array_ops.register() - tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") - v = tensor_func(inputs[2]) - # Write tensor with more static shape - actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape) - if actual_shape != input_t_shape: - new_shape = [] - num_any_dim = 0 - for dim in actual_shape: - if not isinstance(dim, int): - num_any_dim += 1 - new_shape.append(dim if isinstance(dim, int) else -1) - if num_any_dim <= 1: - v = tensor_func(_op.reshape(inputs[2], new_shape)) - - write_func = prelude.get_global_var_static( - "tensor_array_write", dtype_str, input_ta_shape - ) - - return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v) - - return _impl - - -def _tensor_array_read(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr["dtype"].name - input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) - - if input_shape is None: - read_func = prelude.get_global_var("tensor_array_read", dtype_str) - out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) - else: - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) - static_tensor_array_ops.register() - read_func = static_tensor_array_ops.get_global_var("tensor_array_read") - out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) - get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") - out = get_data_func(out_tensor) - - return out - - return _impl - - -def _tensor_array_split(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr.get("T").name - input_ta = inputs[0] - input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - lengths = _op.cast(inputs[2], "int32") - lengths_shape = _infer_shape(lengths, prelude.mod) - value_shape = _infer_shape(inputs[1], prelude.mod) - input_rank = len(value_shape) - - if input_ta_shape is None: - tensor_name = "tensor{}".format(input_rank) - tensor_ctor = prelude.get_tensor_ctor(tensor_name, dtype_str) - v = tensor_ctor(inputs[1]) - split_func = prelude.get_global_var("tensor_array_split", dtype_str) - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( - input_ta_rank, input_rank - ) - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) - static_tensor_array_ops.register() - - # Check static value/indices shape - if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): - static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True) - - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, value_shape) - static_tensor_array_ops.register() - tensor_ctor = static_tensor_array_ops.get_ctor("tensor_constructor") - v = tensor_ctor(inputs[1]) - split_func = prelude.get_global_var_static( - "tensor_array_split", dtype_str, input_ta_shape - ) - - return split_func(input_ta, v, lengths) - - return _impl - - -def _tensor_array_concat(): - def _impl(inputs, attr, params, prelude): - dtype_str = attr["dtype"].name - input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude) - - if input_shape is None: - concat_func = prelude.get_global_var("tensor_array_concat", dtype_str) - out = concat_func(inputs[1]) - else: - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) - static_tensor_array_ops.register() - concat_func = prelude.get_global_var_static( - "tensor_array_concat", dtype_str, input_shape - ) - out_tensor = concat_func(inputs[1]) - out_shape = (Any(),) + input_shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) - static_tensor_array_ops.register() - get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) - out = get_data_func(out_tensor) - - return out - - return _impl - - -def _tile(): - def _impl(inputs, attr, params, mod): - reps_input = inputs.pop() - if isinstance(reps_input, _expr.Call): - np_reps = _infer_value(reps_input, params, mod).numpy() - reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])] - else: - reps = _get_list_param(params, reps_input, mod) - new_input = [inputs.pop(0)] - - return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])( - new_input, attr - ) - - return _impl - - -def _slice(): - def _impl(inputs, attr, params, mod): - try: - begin = _get_list_param(params, inputs[1], mod) - except Exception: - # Handle symbolic begin - begin = inputs[1] - try: - size = _get_list_param(params, inputs[2], mod) - except Exception: - # Handle symbolic size - size = inputs[2] - - # Align begin and strides for dynamic shape. - data_dim = len(_infer_shape(inputs[0], mod)) - strides = [1] * data_dim - if not isinstance(begin, (_expr.Call, _expr.Var)): - for _ in range(len(begin), data_dim): - begin.append(0) - elif not isinstance(size, (_expr.Call, _expr.Var)): - for _ in range(len(size), data_dim): - size.append(-1) - return _op.strided_slice( - inputs[0], begin=begin, end=size, strides=strides, slice_mode="size" - ) - - return _impl - - -def _reshape(): - def _impl(inputs, attr, params, mod): - pop_node = inputs.pop(1) - - try: - shape_arg = _get_tuple_param(params, pop_node) - except AttributeError: - # Shape operator is already pruned, hence - # try to infer shape by precompute prune if possible. - try: - params_new = _infer_value(pop_node, params, mod) - shape_arg = tuple(params_new.numpy().astype("int32").flatten()) - except Exception: - # Deal with symbolic shape case. - if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op): - # shape_of is the direct ancestor. - return _op.reshape_like(inputs[0], pop_node.args[0]) - shape_arg = pop_node - - return AttrCvt(op_name="reshape", extras={"newshape": shape_arg}, ignores=["Tshape"])( - inputs, attr - ) - - return _impl - - -def _depth_to_space(): - def _impl(inputs, attr, params, mod): - block_size = int(attr["block_size"]) - layout = attr["data_format"].decode("utf-8") - return _op.nn.depth_to_space(inputs[0], block_size, layout) - - return _impl - - -def _space_to_depth(): - def _impl(inputs, attr, params, mod): - block_size = int(attr["block_size"]) - layout = attr["data_format"].decode("utf-8") - return _op.nn.space_to_depth(inputs[0], block_size, layout) - - return _impl - - -def _sparse_to_dense(): - def _impl(inputs, attr, params, mod): - sparse_indices = inputs[0] - output_shape = inputs[1] - sparse_values = inputs[2] - default_value = inputs[3] - - return _op.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) - - return _impl - - -def _bias_add(): - def _impl(inputs, attr, params, mod): - # Must expand for proper broadcasting in NCHW. - if "data_format" in attr and attr["data_format"].decode("utf-8") == "NCHW": - bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) - else: - bias = inputs[1] - return _op.add(inputs[0], bias) - - return _impl - - -def _broadcast_args(): - def _impl(inputs, attr, params, mod): - if isinstance(inputs[0], _expr.Var): - s0 = params[inputs[0].name_hint] - else: - s0 = _infer_value(inputs[0], params, mod) - if isinstance(inputs[1], _expr.Var): - s1 = params[inputs[1].name_hint] - else: - s1 = _infer_value(inputs[1], params, mod) - s0 = list(s0.numpy().reshape([-1])) - s1 = list(s1.numpy().reshape([-1])) - s0_size, s1_size = len(s0), len(s1) - - out = deque([]) - for i in range(1, min(s0_size, s1_size) + 1): - if s0[s0_size - i] == s1[s1_size - i]: - out.appendleft(s0[s0_size - i]) - elif s0[s0_size - i] == 1: - out.appendleft(s1[s1_size - i]) - else: - assert s1[s1_size - i] == 1, "Incompatible broadcast type %s and %s" % ( - s0[s0_size - i], - s1[s1_size - i], - ) - out.appendleft(s0[s0_size - i]) - if s0_size < s1_size: - for i in range(s0_size + 1, s1_size + 1): - out.appendleft(s1[s1_size - i]) - if s1_size < s0_size: - for i in range(s1_size + 1, s0_size + 1): - out.appendleft(s0[s0_size - i]) - return _expr.const(list(out), attr["T"].name) - - return _impl - - -def _broadcast_to(): - def _impl(inputs, attr, params, mod): - if isinstance(inputs[1], _expr.Var): - shape = params[inputs[1].name_hint] - else: - shape = _infer_value(inputs[1], params, mod) - shape = list(shape.numpy().reshape([-1])) - return _op.broadcast_to(inputs[0], shape) - - return _impl - - -def _squeeze(): - def _impl(inputs, attr, params, mod): - if len(attr["squeeze_dims"]) == 0: - attr["squeeze_dims"] = None - return AttrCvt( - op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T", "_cloned"] - )(inputs, attr) - - return _impl - - -def _fused_batch_norm(): - def _impl(inputs, attr, params, mod): - # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) - # Relay: (data, gamma, beta, moving_mean, moving_varience) - assert len(inputs) == 5 - axis = 3 - need_cast = False - - if "data_format" in attr: - attr["data_format"] = attr["data_format"].decode("utf-8") - if attr["data_format"] == "NCHW": - axis = 1 - if "U" in attr and attr["U"].name != attr["T"].name: - need_cast = True - inputs[0] = _op.cast(inputs[0], dtype=attr["U"].name) - # Check if mean and variance are empty - # If so, replace them with Mean and Variance Ops - # For run-time calculation - moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] - moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] - if moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0: - inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) - inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) - out = AttrCvt( - op_name="batch_norm", - transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"}, - extras={"axis": axis}, - ignores=["data_format", "U", "exponential_avg_factor"], - disables=["momentum"], - )(inputs, attr) - - if need_cast: - out = _expr.TupleGetItem(out.astuple(), 0) - out = _op.cast(out, dtype=attr["T"].name) - return out - - return _impl - - -def _batch_norm(): - def _impl(inputs, attr, params, mod): - # Rearrange inputs from - # (data, moving_mean, moving_variance, beta, gamma) - # to - # (data, gamma, beta, moving_mean, moving_var) - new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]] - - axis = 3 - if "data_format" in attr: - attr["data_format"] = attr["data_format"].decode("utf-8") - if attr["data_format"] == "NCHW": - axis = 1 - - return AttrCvt( - op_name="batch_norm", - transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"}, - extras={"axis": axis}, - ignores=["data_format", "exponential_avg_factor"], - disables=["momentum"], - )(new_inputs, attr) - - return _impl - - -def _relu6(): - def _impl(inputs, attr, params, mod): - return _op.clip(inputs[0], a_min=0, a_max=6) - - return _impl - - -def _shape(): - def _impl(inputs, attr, params, mod): - is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], mod) - for axis in input_shape: - if not isinstance(axis, (int, tvm.tir.IntImm)): - is_symbolic_shape = True - break - - if is_symbolic_shape: - ret = _op.shape_of(inputs[0], dtype=attr["out_type"].name) - else: - ret = np.array(input_shape, dtype=attr["out_type"].name) - return ret - - return _impl - - -def _fill(): - def _impl(inputs, attr, params, mod): - try: - output_shape = _infer_value(inputs[0], params, mod).numpy().tolist() - except Exception: - output_shape = inputs[0] - - return _op.full(inputs[1], output_shape, attr["T"].name) - - return _impl - - -def _lrn(): - def _impl(inputs, attr, params, mod): - attr_new = {} - depth_radius = attr.get("depth_radius", 5) - size = (depth_radius * 2) + 1 - attr_new["axis"] = 3 # Fix axis, NHWC format - attr_new["size"] = size - attr_new["bias"] = attr.get("bias", 1) - attr_new["alpha"] = attr.get("alpha", 1) * size - attr_new["beta"] = attr.get("beta", 0.5) - return AttrCvt(op_name="lrn")(inputs, attr_new) - - return _impl - - -def _sum(): - def _impl(inputs, attr, params, mod): - axis = _get_tuple_param(params, inputs[1]) - return AttrCvt( - op_name="sum", - extras={"axis": axis}, - transforms={"keep_dims": "keepdims"}, - ignores=["name", "Tidx"], - )([inputs[0]], attr) - - return _impl - - -def _reduce(op): - def _impl(inputs, attr, params, mod): - axis = _get_list_param(params, inputs[1], mod) - axis = tuple(axis) - if not axis: - axis = None - return AttrCvt( - op_name=op, - extras={"axis": axis}, - transforms={"keep_dims": "keepdims"}, - ignores=["name", "Tidx"], - )([inputs[0]], attr) - - return _impl - - -def _euclidean_norm(): - def _impl(inputs, attr, params, mod): - axis = tuple(_get_list_param(params, inputs[1], mod)) - keep_dims = bool(attr.get("keep_dims", False)) - return _op.sqrt( - _op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32") - ) - - return _impl - - -def _square(): - def _impl(inputs, attr, params, mod): - return _op.multiply(inputs[0], inputs[0]) - - return _impl - - -def _gather(): - "GatherV2, Gather" - - def _impl(inputs, attr, params, mod): - if len(inputs) > 2: - axis = _get_num_param(params, inputs.pop(2)) - else: - axis = 0 - batch_dims = 0 - if int(attr.get("batch_dims", 0)) != 0: - batch_dims = int(attr.get("batch_dims", 0)) - new_input = inputs[0:2] - op_ = AttrCvt( - op_name="take", - extras={ - "axis": tvm.tir.const(axis, "int32"), - "batch_dims": tvm.tir.const(batch_dims, "int32"), - }, - ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"], - )(new_input, attr) - return op_ - - return _impl - - -def _gather_nd(): - """GatherNd""" - - def _impl(inputs, attr, params, mod): - indices_dims = len(_infer_shape(inputs[1], mod)) - indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) - return AttrCvt(op_name="gather_nd", ignores=["Tindices", "Tparams", "Taxis", "_class"])( - [inputs[0], indices], attr - ) - - return _impl - - -def _stridedSlice(): - def _impl(inputs, attr, params, mod): - """Strided Slice. - Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice - Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ - tensorflow/core/util/strided_slice_op.cc#L147-L368 - """ - begin = _get_list_param(params, inputs[1], mod) - end = _get_list_param(params, inputs[2], mod) - stride = _get_list_param(params, inputs[3], mod) - - begin_mask = int(attr.get("begin_mask", 0)) - end_mask = int(attr.get("end_mask", 0)) - ellipsis_mask = int(attr.get("ellipsis_mask", 0)) - new_axis_mask = int(attr.get("new_axis_mask", 0)) - shrink_axis_mask = int(attr.get("shrink_axis_mask", 0)) - in_type = _infer_type(inputs[0], mod) - data_shape = get_const_tuple(in_type.checked_type.shape) - data_dim = len(data_shape) - stride_dim = len(stride) - if data_dim == 0 and isinstance(inputs[0], _expr.Constant): - new_data = inputs[0].data.numpy().reshape(1) - return _expr.const(new_data, inputs[0].data.dtype) - - # This is a special routine to handle strided_slice after shape_of. - # We need this since in some cases we want to do strided_slice on - # a partial symbolic shape, such as (1, ?), and get a static shape - # (1,). Directly slice on shape_of will result in fully dynamic shape. - # TODO(kevinthesun): Can we generalize this process with partial eval? - if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"): - bg = begin[0] - ed = end[0] - st = stride[0] - - if ed <= 0 < st: - ed += data_shape[0] - - in_shape = _infer_shape(inputs[0].args[0], mod) - dtype = in_type.checked_type.dtype - out_data = [] - idx = bg - while idx < ed: - if isinstance(in_shape[idx], int): - out_data.append(in_shape[idx]) - else: - break - idx += st - - # Only return when in_shape is fully static in the range from begin to end. - if idx >= ed: - ret = _expr.const(out_data, dtype) - if shrink_axis_mask: - ret = _op.squeeze(ret) - - return ret - - def _transform_mask(stride_dim, ellipsis_mask): - """Handle mask inputs to create new begin, end, stride and output shape""" - m_begin = [0] * data_dim - m_end = [0] * data_dim - m_stride = [0] * data_dim - fshape_indices = [] - # Count new axis after ellipsis_mask, consider while applying ellipsis_mask. - ellipsis_seen = False - new_axes_after_ellipsis = 0 - for i in range(stride_dim): - mask = 1 << i - if ellipsis_seen and (mask & new_axis_mask) != 0: - new_axes_after_ellipsis += 1 - if (mask & ellipsis_mask) != 0: - ellipsis_seen = True - if not ellipsis_seen: - # Used later for extending the stride attributes in the below loop. - ellipsis_mask |= 1 << stride_dim - stride_dim += 1 - final_index = 0 - for index in range(stride_dim): - mask = 1 << index - if mask & ellipsis_mask: - # Identify the end index for applying ellipsis_mask - to_index = min( - ((data_dim - (stride_dim - index)) + 1 + new_axes_after_ellipsis), data_dim - ) - for i in range(final_index, to_index): - m_begin[final_index] = 0 - m_end[final_index] = data_shape[final_index] - m_stride[final_index] = 1 - fshape_indices.append(final_index) - final_index += 1 - elif mask & new_axis_mask: - fshape_indices.append(-1) - elif not mask & new_axis_mask: - if final_index == len(m_begin): - break - if mask & begin_mask: - m_begin[final_index] = -1 if stride[index] < 0 else 0 - elif begin[index]: - m_begin[final_index] = begin[index] - if mask & end_mask: - m_end[final_index] = ( - -(data_shape[final_index] + 1) - if stride[index] < 0 - else data_shape[final_index] - ) - elif end[index]: - m_end[final_index] = end[index] - m_stride[final_index] = stride[index] - if mask & shrink_axis_mask: - # Tensorflow make axis with shrink_axis_mask as dimension 1 - m_begin[final_index] = ( - data_shape[final_index] + begin[index] - if begin[index] < 0 - else begin[index] - ) - m_end[final_index] = begin[index] + 1 - m_stride[final_index] = 1 - fshape_indices.append(-2) - else: - fshape_indices.append(final_index) - - final_index += 1 - return m_begin, m_end, m_stride, fshape_indices - - fshape_indices = None - if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: - begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) - out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_shape(out, mod=mod) - if not fshape_indices: - fshape_indices = range(len(out_shape)) - - # Create final output shape. - final_output = [] - for gather_index in fshape_indices: - if gather_index == -1: - final_output.append(1) - elif gather_index == -2: - pass - else: - final_output.append(out_shape[gather_index]) - - if not final_output: - if not shrink_axis_mask: - ret = out - else: - final_shape = [] - for dim in out_shape: - if dim != 1: - final_shape.append(dim) - if len(final_shape) == 0: - ret = _op.squeeze(out) - else: - # We need reshape to handle dynamic shape. - ret = _op.reshape(out, newshape=tuple(final_shape)) - else: - ret = _op.reshape(out, newshape=tuple(final_output)) - return ret - - return _impl - - -def _pad(name): - def _impl(inputs, attr, params, mod): - try: - padlist = _get_param(params, inputs[1]) - except (IndexError, KeyError, AttributeError): - try: - padlist = _infer_value(inputs[1], params, mod).numpy().tolist() - except Exception: - padlist = inputs[1] - - if isinstance(padlist, _expr.Expr): - paddings = padlist - else: - paddings = tuple(tuple(l) for l in padlist) - attr["pad_width"] = paddings - attr["pad_value"] = 0 - new_inputs = [inputs[0]] - if name == "PadV2": - try: - attr["pad_value"] = _get_num_param(params, inputs[2]) - except (IndexError, KeyError, AttributeError): - attr["pad_value"] = inputs[2] - return AttrCvt( - op_name="pad", - ignores=["Tpaddings"], - )(new_inputs, attr) - - return _impl - - -def _mirror_pad(): - def _impl(inputs, attr, params, mod): - padlist = _get_param(params, inputs[1]) - paddings = tuple(tuple(l) for l in padlist) - attr["pad_width"] = paddings - mode = attr["mode"].decode("utf-8") - attr["mode"] = mode - new_inputs = [inputs[0]] - return AttrCvt( - op_name="mirror_pad", - ignores=["Tpaddings"], - )(new_inputs, attr) - - return _impl - - -def _transpose(): - def _impl(inputs, attr, params, mod): - # If perm is not specified, axes is left empty, - # otherwise its value is get from params - axes = _get_list_param(params, inputs[1], mod) - return _op.transpose(inputs[0], axes=axes) - - return _impl - - -def _where(): - def _impl(inputs, attr, params, mod): - if len(inputs) == 1: - return AttrCvt(op_name="argwhere")(inputs, attr) - return AttrCvt(op_name="where")(inputs, attr) - - return _impl - - -def _clip_by_value(): - def _impl(inputs, attr, params, mod): - a_min = _get_num_param(params, inputs[1]) - a_max = _get_num_param(params, inputs[2]) - return _op.clip(inputs[0], a_min=a_min, a_max=a_max) - - return _impl - - -def _reverse_v2(): - def _impl(inputs, attr, params, mod): - axis = _get_num_param(params, inputs[1]) - return AttrCvt(op_name="reverse", ignores=["Tidx"], extras={"axis": int(axis)})( - [inputs[0]], attr - ) - - return _impl - - -def _rank(): - def _impl(inputs, attr, params, mod): - input_shape = _infer_shape(inputs[0], mod) - - name = attr["_node_name"] - params[name] = tvm.nd.array(np.array([len(input_shape)]).astype("int32")) - return [_expr.var(name, shape=params[name].shape, dtype="int32")] - - return _impl - - -def _range(): - def _impl(inputs, attr, params, mod): - try: - start = _get_param(params, inputs[0])[0] - except (IndexError, KeyError, AttributeError): - try: - start = _infer_value(inputs[1], params, mod).numpy().tolist() - start = start if not isinstance(start, list) else start[0] - except Exception: - # Symbolic start - start = inputs[0] - - try: - limit = ( - _get_param(params, inputs[1])[0] - if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) - else params.pop("Rank").numpy()[0] - ) - except (IndexError, KeyError, AttributeError): - try: - limit = _infer_value(inputs[1], params, mod).numpy().tolist() - limit = limit if not isinstance(limit, list) else limit[0] - except Exception: - limit = inputs[1] - - try: - delta = _get_param(params, inputs[2])[0] - except (IndexError, KeyError, AttributeError): - try: - delta = _infer_value(inputs[2], params, mod).numpy().tolist() - delta = delta if not isinstance(delta, list) else delta[0] - except Exception: - # Symbolic delta - delta = inputs[2] - - # if all attributes are constant, evalute the range function and return relay.const - if all( - [ - isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)), - isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)), - isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)), - ] - ): - return tvm.relay.const(list(range(int(start), int(limit), int(delta)))) - - dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) - if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): - start = _expr.const(start, dtype=dtype) - if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)): - limit = _expr.const(limit, dtype=dtype) - if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)): - delta = _expr.const(delta, dtype=dtype) - - return AttrCvt( - op_name="arange", - ignores=["Tidx", "_class"], - extras={"start": start, "stop": limit, "step": delta, "dtype": dtype}, - )([], attr) - - return _impl - - -def _elu(): - def _impl(inputs, attr, params, mod): - dtype = attr["T"].name - alpha = tvm.relay.const(-1.0, dtype) - return alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) + _op.nn.relu( - inputs[0] - ) - - return _impl - - -def _selu(): - def _impl(inputs, attr, params, mod): - dtype = attr["T"].name - alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) - gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) - return gamma * ( - alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) - + _op.nn.relu(inputs[0]) - ) - - return _impl - - -def _mean(): - def _impl(inputs, attr, params, mod): - axis = _get_tuple_param(params, inputs[1]) - return AttrCvt( - op_name="mean", - ignores=["Tdim", "Tidx"], - transforms={"keep_dims": "keepdims"}, - extras={"axis": axis}, - )([inputs[0]], attr) - - return _impl - - -def _broadcast(name): - def _impl(inputs, attr, params, mod): - return AttrCvt(op_name=name, ignores=["name", "incompatible_shape_error", "Tidx"])( - inputs, attr - ) - - return _impl - - -def _split(has_size_vector): - # TF documentation https://www.tensorflow.org/api_docs/python/tf/split - def _impl(inputs, attr, params, mod): - try: - # order and number of inputs are different: - # if has_size_vector: - # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v - # else: - # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split - - # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow, - # we can only support constants - if has_size_vector: - input_node_index = 0 - input_axis_index = 2 - size_splits = _get_param(params, inputs[1]) - section_beginnings = np.cumsum(size_splits)[:-1] - indices_or_sections = tuple(section_beginnings) - else: - input_node_index = 1 - input_axis_index = 0 - indices_or_sections = attr["num_split"] - input_node = inputs[input_node_index] - axis_input_value = _get_num_param(params, inputs[input_axis_index]) - except (IndexError, KeyError, AttributeError): - raise TypeError( - "Unsupported argument for split: `axis` and `num_or_size_splits` " - "should be constants" - ) - return _op.split( - input_node, indices_or_sections=indices_or_sections, axis=int(axis_input_value) - ) - - return _impl - - -def _unpack(): - def _impl(inputs, attr, params, mod): - input_node = inputs[0] - axis = attr["axis"] - input_shape = _infer_shape(input_node, mod) - axis_length = input_shape[axis] - if axis_length < 0: - raise TypeError("Unstack with unknown axis length") - splitted = _op.split(input_node, indices_or_sections=axis_length, axis=axis) - axis = [axis] - return _expr.TupleWrapper( - _expr.Tuple([_op.squeeze(split_item, axis=axis) for split_item in splitted]), - len(splitted), - ) - - return _impl - - -def _softmax(): - def _impl(inputs, attr, params, mod): - return AttrCvt(op_name="softmax", transforms={"axis": ("axis", 1)})([inputs[0]], attr) - - return _impl - - -def _softsign(): - # op description: https://www.tensorflow.org/api_docs/python/tf/math/softsign - def _impl(inputs, attr, params, mod): - abs_out = get_relay_op("abs")(inputs[0]) - add_out = abs_out + tvm.relay.const(1, attr["T"].name) - return inputs[0] / add_out - - return _impl - - -def _softplus(): - # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus - def _impl(inputs, attr, params, mod): - exp_out = AttrCvt("exp")(inputs, attr) - inputs.append(tvm.relay.const(1, attr["T"].name)) - rh = tvm.relay.const(1, attr["T"].name) - add_out = get_relay_op("add")(exp_out, rh) - return get_relay_op("log")(add_out) - - return _impl - - -def _topk(): - def _impl(inputs, attr, params, mod): - k_input = inputs.pop(1) - try: - k = int(_get_num_param(params, k_input)) - except (IndexError, KeyError, AttributeError): - try: - k = int(_infer_value(k_input, params, mod).numpy().tolist()) - except Exception: - k = k_input - if isinstance(k, int): - if k < 1: - raise tvm.error.OpAttributeInvalid( - "Attribute k must be positive in operator TopKV2" - ) - k = _expr.const(k) - if attr["sorted"] is False: - raise tvm.error.OpAttributeUnImplemented( - "Attribute sorted=False is not supported in operator TopKV2" - ) - return AttrCvt( - op_name="topk", - ignores=["sorted"], - extras={"k": k, "is_ascend": False, "dtype": "int32"}, - )([inputs[0]], attr) - - return _impl - - -def _floordiv(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 2 - return AttrCvt("floor_divide")(inputs, attr) - - return _impl - - -def _floormod(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 2 - return AttrCvt("floor_mod")(inputs, attr) - - return _impl - - -def _logical(name): - def _impl(inputs, attr, params, mod): - return AttrCvt(op_name=name)(inputs, attr) - - return _impl - - -def _space_to_batch_nd(): - def _impl(inputs, attr, params, mod): - block_shape = _get_list_param(params, inputs[1], mod) - - paddings = _get_list_param(params, inputs[2], mod) - paddings = np.squeeze(paddings) - if len(paddings.shape) == 1: - paddings = np.expand_dims(paddings, axis=0) - paddings = paddings.tolist() - - attr["block_shape"] = block_shape - attr["paddings"] = paddings - out = AttrCvt("space_to_batch_nd", ignores=["Tblock_shape", "Tpaddings"])([inputs[0]], attr) - - return out - - return _impl - - -def _batch_to_space_nd(): - def _impl(inputs, attr, params, mod): - block_shape = _get_list_param(params, inputs[1], mod) - - crops = _get_list_param(params, inputs[2], mod) - crops = np.squeeze(crops) - if len(crops.shape) == 1: - crops = np.expand_dims(crops, axis=0) - crops = crops.tolist() - - attr["block_shape"] = block_shape - attr["crops"] = crops - out = AttrCvt("batch_to_space_nd", ignores=["Tblock_shape", "Tcrops"])([inputs[0]], attr) - - return out - - return _impl - - -def _atan2(): - def _impl(inputs, attr, params, mod): - divide = _elemwise("divide")(inputs, attr, params, mod) - return get_relay_op("atan")(divide) - - return _impl - - -def _prod(): - def _impl(inputs, attr, params, mod): - axis = _get_num_param(params, inputs[1]) - keepdims = attr["keep_dims"] - return _op.prod(inputs[0], int(axis), keepdims=keepdims) - - return _impl - - -def _log1p(): - # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p - def _impl(inputs, attr, params, mod): - one = tvm.relay.const(1, attr["T"].name) - add_out = get_relay_op("add")(inputs[0], one) - return get_relay_op("log")(add_out) - - return _impl - - -def _one_hot(): - def _impl(inputs, attr, params, mod): - depth = int(_get_num_param(params, inputs[1])) - dtype = attr["T"].name - - on_value = _get_num_param(params, inputs[2]) - off_value = _get_num_param(params, inputs[3]) - new_inputs = [ - inputs[0], - tvm.relay.const(on_value, dtype), - tvm.relay.const(off_value, dtype), - ] - return AttrCvt("one_hot", ignores=["TI"], extras={"depth": depth, "dtype": dtype})( - new_inputs, attr - ) - - return _impl - - -def _squared_difference(): - def _impl(inputs, attr, params, mod): - difference = _op.subtract(inputs[0], inputs[1]) - return _op.multiply(difference, difference) - - return _impl - - -def _size(): - def _impl(inputs, attr, params, mod): - new_attr = attr - new_attr["out_type"] = attr["out_type"].name - return AttrCvt("ndarray_size", transforms={"out_type": "dtype"})(inputs, new_attr) - - return _impl - - -def _add_n(): - def _impl(inputs, attr, params, mod): - if not isinstance(inputs, tuple): - inputs = list(inputs) - assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." - _res = inputs[0] - for each in inputs[1:]: - _res = _op.add(_res, each) - return _res - - return _impl - - -def _LSTMBlockCell(): - def _impl(inputs, attr, params, mod): - """LSTM Block cell. - Calculations and return values are described in: - https://github.com/tensorflow/tensorflow/blob/ - r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 - - Parameters - ---------- - inputs : relay.Expr - Input data - in_state_c: list of relay.Expr - Cell state input values for all the layers - in_state_h: list of relay.Expr - Hidden state input values for all the layers - attrs : dict - Dict of operator attributes - params : dict - List of pretrained weights and bias - - Returns - ------- - relay.Expr.TupleWapper - [i, cs, f, o, ci, co, h] - """ - in_data = inputs[0] - in_state_c = inputs[1] - in_state_h = inputs[2] - in_weight = inputs[3] - in_bias = inputs[7] - forget_bias = attr.pop("forget_bias") - input_shape = _infer_shape(inputs[0], mod) - weight_shape = _infer_shape(inputs[3], mod) - batch_size, input_size = input_shape[0], input_shape[1] - num_hidden_layers = weight_shape[1] - - in_data = _op.reshape(in_data, newshape=(batch_size, input_size)) - ixh = _op.concatenate([in_data, in_state_h], axis=1) - in_weight = _op.transpose(in_weight, axes=None) - gates = _op.nn.dense(ixh, in_weight, units=num_hidden_layers) - gates_bias = _op.add(gates, in_bias) - gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) - in_gate = _op.sigmoid(gate_list[0]) - in_transform = _op.tanh(gate_list[1]) - forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr["T"].name)) - forget_gate = _op.sigmoid(forget_gate) - out_gate = _op.sigmoid(gate_list[3]) - next_c = _op.add(_op.multiply(forget_gate, in_state_c), _op.multiply(in_gate, in_transform)) - co = _op.tanh(next_c) - next_h = out_gate * co - - return tvm.relay.TupleWrapper( - tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7 - ) - - return _impl - - -def _unique(return_counts=True): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 1 - data = inputs[0] - if return_counts: - [unique, indices, num_uniq, counts] = _op.unique( - data, is_sorted=False, return_counts=True - ) - unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") - return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices, counts_sliced]), - 3, - ) - [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) - unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices]), - 2, - ) - - return _impl +from .tensorflow_ops import _convert_map +from .tensorflow_ops import _need_prelude_for_shape_inference +from .tensorflow_ops import _get_more_static_shape +__all__ = ["from_tensorflow"] # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2736,189 +59,6 @@ def _impl(inputs, attr, params, mod): "AssignVariableOp", ] - -# _convert_map defines maps of name to converter functor(callable) -# for 1 to 1 mapping, use Renamer if nothing but name is different -# use AttrCvt if attributes need to be converted -# for 1 to N mapping(composed), use custom callable functions -# for N to 1 mapping, currently not supported(?) -_convert_map = { - "Abs": AttrCvt("abs"), - "Acos": AttrCvt("acos"), - "Acosh": AttrCvt("acosh"), - "Add": _elemwise("add"), - "AddN": _add_n(), - "AddV2": _elemwise("add"), - "All": _reduce("all"), - "Any": _reduce("any"), - "ArgMax": _argx(_op.argmax, "argmax"), - "ArgMin": _argx(_op.argmin, "argmin"), - "Asin": AttrCvt("asin"), - "Asinh": AttrCvt("asinh"), - "Assert": _assert(), - "Atan": AttrCvt("atan"), - "Atanh": AttrCvt("atanh"), - "Atan2": _atan2(), - "AvgPool": _pooling("avg_pool"), - "AvgPool3D": _pool3d("avg_pool3d"), - "BatchMatMul": _batch_matmul(), - "BatchMatMulV2": _batch_matmul(), - "BatchNormWithGlobalNormalization": _batch_norm(), - "BatchToSpaceND": _batch_to_space_nd(), - "BiasAdd": _bias_add(), - "BroadcastTo": _broadcast_to(), - "BroadcastArgs": _broadcast_args(), - "Cast": _cast(), - "Ceil": AttrCvt("ceil"), - "CheckNumerics": _check_numerics(), - "ClipByValue": _clip_by_value(), - "Concat": _concat(), - "ConcatV2": _concatV2(), - "Conv2D": _conv("conv"), - "Conv2DBackpropInput": _conv("conv_transpose"), - "Conv3D": _conv3d("conv"), - "Conv3DBackpropInputV2": _conv3d("conv_transpose"), - "Cos": AttrCvt("cos"), - "Cosh": AttrCvt("cosh"), - "CropAndResize": _crop_and_resize(), - "DecodeJpeg": _decode_image(), - "DepthToSpace": _depth_to_space(), - "DepthwiseConv2dNative": _conv("depthwise"), - "Dilation2D": _dilation2d(), - "Elu": _elu(), - "Equal": _broadcast("equal"), - "Erf": AttrCvt("erf"), - "EuclideanNorm": _euclidean_norm(), - "Exp": AttrCvt("exp"), - "ExpandDims": _expand_dims(), - "Expm1": _expm1(), - "Fill": _fill(), - "Floor": AttrCvt("floor"), - "FloorDiv": _floordiv(), - "FloorMod": _floormod(), - "FusedBatchNorm": _fused_batch_norm(), - "FusedBatchNormV2": _fused_batch_norm(), - "FusedBatchNormV3": _fused_batch_norm(), - "Gather": _gather(), - "GatherNd": _gather_nd(), - "GatherV2": _gather(), - "Greater": _broadcast("greater"), - "GreaterEqual": _broadcast("greater_equal"), - "Identity": _identity(), - "IdentityN": _identityn(), - "IsFinite": AttrCvt("isfinite"), - "IsInf": AttrCvt("isinf"), - "IsNan": AttrCvt("isnan"), - "LeakyRelu": AttrCvt("leaky_relu"), - "LeftShift": AttrCvt("left_shift"), - "Less": _broadcast("less"), - "LessEqual": _broadcast("less_equal"), - "Log": AttrCvt("log"), - "Log1p": _log1p(), - "LogicalAnd": _logical("logical_and"), - "LogicalNot": _logical("logical_not"), - "LogicalOr": _logical("logical_or"), - "LogSoftmax": AttrCvt("log_softmax"), - "LRN": _lrn(), - "LSTMBlockCell": _LSTMBlockCell(), - "MatMul": _matmul(), - "Max": _reduce("max"), - "Maximum": _elemwise("maximum"), - "MaxPool": _pooling("max_pool"), - "MaxPool3D": _pool3d("max_pool3d"), - "Mean": _mean(), - "Min": _reduce("min"), - "Minimum": _elemwise("minimum"), - "MirrorPad": _mirror_pad(), - "Mod": _elemwise("mod"), - "Mul": _elemwise("multiply"), - "Neg": AttrCvt("negative"), - "NonMaxSuppressionV2": _nms(), - "NonMaxSuppressionV3": _nms(), - "NonMaxSuppressionV4": _nms(), - "NonMaxSuppressionV5": _nms(True), - "CombinedNonMaxSuppression": _combined_nms(), - "NoOp": _no_op(), - "NotEqual": _broadcast("not_equal"), - "OneHot": _one_hot(), - "Pack": _pack(), - "Pad": _pad("Pad"), - "PadV2": _pad("PadV2"), - "Pow": _elemwise("power"), - "Prod": _prod(), - "Range": _range(), - "Rank": _rank(), - "RealDiv": _elemwise("divide"), - "Relu": AttrCvt("relu"), - "Relu6": _relu6(), - "Reshape": _reshape(), - "ResizeBicubic": _resize("bilinear"), - "ResizeBilinear": _resize("bilinear"), - "ResizeNearestNeighbor": _resize("nearest_neighbor"), - "ReverseV2": _reverse_v2(), - "RightShift": AttrCvt("right_shift"), - "Rint": AttrCvt("round"), - "Round": AttrCvt("round"), - "Rsqrt": _rsqrt(), - "Select": _where(), - "SelectV2": _where(), - "Selu": _selu(), - "Shape": _shape(), - "Sigmoid": AttrCvt("sigmoid"), - "Sign": AttrCvt("sign"), - "Sin": AttrCvt("sin"), - "Sinh": AttrCvt("sinh"), - "Size": _size(), - "Slice": _slice(), - "Softmax": _softmax(), - "Softplus": _softplus(), - "Softsign": _softsign(), - "SpaceToBatchND": _space_to_batch_nd(), - "SpaceToDepth": _space_to_depth(), - "SparseToDense": _sparse_to_dense(), - "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), - "SparseFillEmptyRows": _sparse_fill_empty_rows(), - "SparseReshape": _sparse_reshape(), - "SegmentSum": _math_segment_sum(), - "SparseSegmentSum": _sparse_segment_sum(), - "SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(), - "SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(), - "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(), - "SparseSegmentMean": _sparse_segment_mean(), - "SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(), - "SparseTensorDenseAdd": _sparse_tensor_dense_add(), - "Split": _split(False), - "SplitV": _split(True), - "Sqrt": AttrCvt("sqrt"), - "Square": _square(), - "SquaredDifference": _squared_difference(), - "Squeeze": _squeeze(), - "StopGradient": _identity(), - "StridedSlice": _stridedSlice(), - "Sub": _elemwise("subtract"), - "Sum": _sum(), - "Tan": AttrCvt("tan"), - "Tanh": AttrCvt("tanh"), - "TensorArrayConcatV3": _tensor_array_concat(), - "TensorArrayGatherV3": _tensor_array_gather(), - "TensorArrayReadV3": _tensor_array_read(), - "TensorArrayScatterV3": _tensor_array_scatter(), - "TensorArraySizeV3": _tensor_array_size(), - "TensorArraySplitV3": _tensor_array_split(), - "TensorArrayV3": _tensor_array(), - "TensorArrayWriteV3": _tensor_array_write(), - "Tile": _tile(), - "TopKV2": _topk(), - "Transpose": _transpose(), - "TruncateMod": _elemwise("mod"), - "Unique": _unique(False), - "UniqueWithCounts": _unique(True), - "Unpack": _unpack(), - "UnravelIndex": _unravel_index(), - "Where": _where(), - "ZerosLike": AttrCvt("zeros_like"), -} - # An internal list to contain all the control flow primitives used in Tensorflow # 1.x. _control_flow_nodes = ["Merge", "Switch", "NextIteration", "Exit", "Enter", "LoopCond"] diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py new file mode 100644 index 000000000000..e5339b33c4e9 --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -0,0 +1,658 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +"""Tensorflow2.x graph to relay converter. + +If model is constructed using tf2.x API, then use this converter: + from tvm.relay.frontend.tensorflow2 import from_tensorflow +Otherwise use the tf1.x converter: + from tvm.relay.frontend.tensorflow import from_tensorflow + +""" + +import numpy as np +from tensorflow.python.framework import function_def_to_graph, tensor_util, dtypes + +import tvm +from tvm.relay.transform import InferType +from tvm.relay.prelude import Prelude +from tvm.ir import IRModule +from .. import expr as _expr +from .. import analysis +from .. import function as _function +from ..loops import while_loop as _while_loop +from .common import infer_type as _infer_type + +from .tensorflow_ops import _convert_map as _convert_map_common +from .tensorflow_ops import _need_prelude_for_shape_inference + +from ..ty import Any + +__all__ = ["from_tensorflow"] + + +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + + +def set_span(sym, node_name): + """set span of symbol""" + + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call): + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call): + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + + +def convert_const_node(node, shape): + """convert tf const node into relay const or var""" + + # get the value of the constant + tensor_value = node.attr["value"].tensor + np_array = tensor_util.MakeNdarray(tensor_value) + + if np_array.dtype == np.dtype(object): + if shape and node.name in shape: + var_shape = shape[node.name] + else: + var_shape = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) + param = None + sym = [_expr.var(node.name, shape=var_shape, dtype="uint8")] + return sym, param + + if len(np_array.shape) == 0: + param = None + sym = [tvm.relay.const(np_array, np_array.dtype)] + else: + param = tvm.nd.array(np_array) + sym = [_expr.var(node.name, shape=param.shape, dtype=param.dtype)] + + return sym, param + + +def get_attr(buf): + """convert value of a node attribute. node attribute is part of a node in a graph. + + Parameters + ---------- + buf: attrvalue protobuf. + + Returns + ------- + The value of the attr, as a Python object. + + Raises: + ------- + ValueError: If this op does not have an attr with the given `name`. + """ + + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] + + ret = [] + + if not buf.WhichOneof("value"): + return ret + + if buf.HasField("list"): + for f in fields: + if getattr(buf.list, f): + if f == "type": + ret += [dtypes.as_dtype(x) for x in list(getattr(buf.list, f))] + else: + ret += list(getattr(buf.list, f)) + else: + for f in fields: + if buf.HasField(f): + if f == "type": + ret = dtypes.as_dtype(getattr(buf, f)) + else: + ret = getattr(buf, f) + return ret + + +def parse_attr(attr_proto): + """Convert node attributes (a serialized map of key-value pairs) in a node to a dict + + Parameters + ---------- + attr_proto: + + Returns + ------- + Dict {string: python object} + + """ + attrs = {} + for key, value in attr_proto.items(): + attrs[key] = get_attr(value) + + return attrs + + +def convert_placeholder(shape, node, in_type=None): + """convert tf placeholder into relay var. + + Example + -------- + a tf placeholder with name "x" is converted to [Var(x, ty=TensorType([], float32))] + """ + + if shape and node.name in shape: + input_shape = list(shape[node.name]) + else: + input_shape = tensor_util.TensorShapeProtoToList(node.attr["shape"].shape) + for idx, dim in enumerate(input_shape): + if dim < 0: + input_shape[idx] = Any() + attr = parse_attr(node.attr) + if in_type is not None: + sym = [_expr.var(node.name, type_annotation=in_type)] + else: + sym = [_expr.var(node.name, shape=input_shape, dtype=attr["dtype"].name)] + return input_shape, sym + + +class RelayModule: + """states related to the entire relay module (multiple functions) + after converted from tf graphdef""" + + def __init__(self): + self.mod = IRModule({}) + self.params = {} + self.prelude = Prelude(self.mod) + + +class GraphProto: + """Capturing states when converting a tf graph to a single relay function.""" + + def __init__(self, module): + self._module = module + self._prelude = self._module.prelude + self._params = {} + self._nodes = {} + self._input_shapes = {} + self._output_shapes = {} + self._tf_node_map = {} + self._gdef_lib = {} + + def from_tensorflow( + self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None + ): + """Wrapper to _get_relay_func which converts Tensorflow graph to Relay function + which is used as main function for the Relay module + """ + if input_types is None: + input_types = {} + + if gdef_lib is None: + gdef_lib = {} + + self._gdef_lib = gdef_lib + func = self._get_relay_func( + graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types + ) + return func, self._params + + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None): + if input_types is None: + input_types = {} + + self._layout = layout + for node in graph.node: + name = node.name + self._tf_node_map[name] = node + if node.op == "Placeholder": + in_type = None + if node.name in input_types: + in_type = input_types[node.name] + self._input_shapes[name], self._nodes[name] = convert_placeholder( + shape, node, in_type + ) + elif node.op == "Const": + sym, param = convert_const_node(node, shape) + self._nodes[node.name] = sym + if param: + self._params[node.name] = param + for node in graph.node: + self._backtrack_construct(graph, node.name) + + return self._func(graph, outputs) + + def _func(self, graph, outputs): + out = [] + if outputs is None: + last_node = graph.node[-1] + op = self._nodes[last_node.name.split(":")[0]] + if last_node.op == "Exit": + out = [op[0].tuple_value] + else: + out = op + else: + for out_name in outputs: + if ":" in out_name: + out_name = out_name.split(":") + out_name, out_num = out_name[0], out_name[-1] + out_num = int(out_num) + out.append(self._nodes[out_name][out_num]) + else: + out.append(self._nodes[out_name][0]) + + if isinstance(out, _expr.TupleWrapper): + out = out.astuple() + else: + out = out[0] if len(out) == 1 else _expr.Tuple(out) + + fvars = analysis.free_vars(out) + func = _function.Function(fvars, out) + final_params = {} + for fv in fvars: + if fv.name_hint in self._params: + final_params[fv.name_hint] = self._params[fv.name_hint] + self._params = final_params + return func + + def _convert_operator(self, graph, op_name, node_name, inputs, attrs): + """Convert from Tensorflow operator to relay operator. + The converter must specify conversions explicitly for incompatible name, and + apply handlers to operator attributes. + + Parameters + ---------- + graph: + TF2 frozen graph def + op_name : str + Operator name, such as Conv2D, AvgPool + node_name: str + Name of the node in TF2 graph, such as Identity:0 + inputs : list of relay.op + List of input symbols. + attrs : dict + Dict of operator attributes + + Returns + ------- + sym : relay.op + Converted relay operator + """ + if op_name in ["PartitionedCall", "StatefulPartitionedCall"]: + sym = _partition_call_operator( + self._module, + graph, + inputs, + attrs, + self._prelude, + gdef_lib=self._gdef_lib, + ) + elif op_name in ["StatelessIf", "If"]: + sym = _convert_if( + self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib + ) + elif op_name in ["StatelessWhile", "While"]: + sym = _convert_loop( + self._module, + graph, + inputs, + attrs, + node_name, + self._tf_node_map, + self._prelude, + gdef_lib=self._gdef_lib, + ) + elif op_name in _convert_map_common: + if _need_prelude_for_shape_inference(op_name): + sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + sym = set_span(sym, node_name) + return sym + + def _backtrack_construct(self, graph, node_name): + """Convert a specific tensorflow node to relay expression. + + If any of its ancestor node is not converted yet, backtrack as + far as input node and covert all nodes on the path. resurion is used here. + + This is required when parsing control flow nodes, since the parsing + order may not follow the original graph def. + + to discover input node, current tf node's input is iterated: + + tensorflow/core/framework/node_def.proto + message NodeDef { + repeated string input = 3; + } + + a node has many inputs (other nodes). each input has the following format: + data input is "node:src_output". node is the string name. + control input is "^node". + + Parameters + ---------- + graph : + TF2 frozen graph def + + node_name : str + node name + + Returns + ------- + op : relay.Expr + Converted relay expression. + + Examples + -------- + tf expression "x+1" is converted to relay expression: + CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) + + """ + + input_op_name = node_name.split(":")[0].split("^")[-1] + if input_op_name not in self._nodes: + node = self._tf_node_map[input_op_name] + attr = parse_attr(node.attr) + if "_output_shapes" in attr: + self._output_shapes[node.name] = [ + tensor_util.TensorShapeProtoToList(tshape) for tshape in attr["_output_shapes"] + ] + else: + self._output_shapes[node.name] = [None] + + attr["_output_shapes"] = self._output_shapes[input_op_name] + attr["_node_name"] = node.name + attr["_target_layout"] = self._layout + inputs = [self._backtrack_construct(graph, iname) for iname in node.input] + op = self._convert_operator(graph, node.op, node.name, inputs, attr) + + if isinstance(op, np.ndarray): + self._params[node.name] = tvm.nd.array(op) + op = [ + _expr.var( + node.name, + shape=self._params[node.name].shape, + dtype=self._params[node.name].dtype, + ) + ] + elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): + op = [op] + self._nodes[input_op_name] = op + + out = self._nodes[input_op_name] + if isinstance(out, _expr.TupleWrapper): + tn = node_name.split(":") + tensor_slot = int(tn[1]) if len(tn) > 1 else 0 + return out[tensor_slot] + + return out[0] + + +def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib): + """convert tf PartitionedCall node to a relay function call""" + node_func_name = attr.get("f").name + return _convert_function( + module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib + ) + + +def _convert_if(module, graph, inputs, attr, prelude, gdef_lib): + """Convert tf If/StatelessIf to Relay If""" + cond_expr = inputs[0] + branch_names = [attr.get(x).name for x in ["then_branch", "else_branch"]] + then_fn, else_fn = [ + _convert_function(module, graph, inputs[1:], attr, name, prelude, gdef_lib=gdef_lib) + for name in branch_names + ] + out = _expr.If(cond_expr, then_fn, else_fn) + return out + + +def _convert_loop(module, graph, inputs, attr, node_name, nodes, prelude, gdef_lib): + """convert tf while_loop to Relay loop""" + input_size = len(inputs) + cond_fn_name, body_fn_name = [attr.get(x).name for x in ["cond", "body"]] + + def convert_vars(loop_inputs, input_signature): + """convert inputs to relay vars to be used as loop variables + Loop inputs are packed as: + [iteration_number, max_iterations, loop_variables...] + """ + new_vars = [] + for i, v in enumerate(loop_inputs): + if isinstance(v, _expr.Constant): + vtype = _infer_type(v).checked_type.dtype + new_vars.append(_expr.var(input_signature[i].name, shape=(), dtype=vtype)) + else: + vtype = _infer_type_with_prelude(v, prelude) + new_vars.append(_expr.var(input_signature[i].name, type_annotation=vtype)) + return new_vars + + while_func = next( + (f for f in graph.library.function if f.signature.name == attr["body"].name), + None, + ) + loop_inputs = convert_vars(inputs, while_func.signature.input_arg) + + def cond_fn(*loop_inputs): + return _convert_function( + module, graph, loop_inputs, attr, cond_fn_name, prelude, gdef_lib=gdef_lib + ) + + # Define the loop body, in this function we need to unpack loop inputs, + # convert the loop subgraph, and pack outputs for the next iteration. + def body_fn(*loop_inputs): + # Increment loop iteration counter + loop_count = loop_inputs[0] + _expr.const(1, dtype="int32") + max_count = loop_inputs[1] + fn = _convert_function( + module, graph, loop_inputs, attr, body_fn_name, prelude, gdef_lib=gdef_lib + ) + + # Repack loop variables + out = [loop_count, max_count] + [_expr.TupleGetItem(fn, i) for i in range(2, input_size)] + return out + + loop = _while_loop(cond_fn, loop_inputs, body_fn) + outputs = loop(*inputs) + outputs = _expr.TupleWrapper( + _expr.Tuple([_expr.TupleGetItem(outputs, i) for i in range(input_size)]), input_size + ) + return outputs + + +def _convert_function( + module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None +): + """Convert given tf node to a relay function call + + Parameters + ---------- + module : IRModule + where converted function is stored + + graph: + top level tf graphdef + + inputs : List[tvm.relay.Expr] + List of input symbols. Parameters for the function. + + attrs : Dict[tvm.Attrs] + Dict of operator attributes. + + node_func_name : str + Name of tf2 node to be converted + + Returns + ------- + op : tvm.relay.Expr + + + Examples + -------- + a tf function "x+1", is implemented as a subgraph in the libary section of the graph. + this subgraph is converted to a relay function such as + fn (%x: float32) { + add(%x, 1f) /* Identity */ + } + + the subgraph has a function name such as __inference_add_95 + the tf function call operator is returned as relay expression, such as: + free_var %x: float32; + @func___inference_add_95(%x) + + """ + func = next( + (f for f in graph.library.function if f.signature.name == node_func_name), + None, + ) + if func is None: + raise Exception("Function not found - {}".format(node_func_name)) + devices = set(node.device for node in func.node_def) + if len(devices) > 1: + raise Exception( + "node_def in function {} contains > 1 types of devices {}".format( + node_func_name, devices + ) + ) + + subgraph = gdef_lib[node_func_name] + # preserve library functions in subgraphs to make them available to nested functions + for fn in graph.library.function: + subgraph.library.function.add().CopyFrom(fn) + + # Computing subgraph's input shape and type dictionaries + input_expr_dict = {} + input_types = {} + for f_arg, input_ in zip(func.signature.input_arg, inputs): + input_expr_dict[f_arg.name] = input_ + input_types[f_arg.name] = _infer_type_with_prelude(input_, prelude) + + func_name = "func_{}".format(func.signature.name) + try: + global_func = module.mod[func_name] + sub_func = global_func + sub_params = module.params + except ValueError: + # Construct relay nodes from the subgraph + g1 = GraphProto(module) + output_sig = [func.ret[f.name] for f in func.signature.output_arg] + # TODO: unify prelude and main IRModules + sub_func, sub_params = g1.from_tensorflow( + subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib + ) + module.params.update(sub_params) + func_expr = _function.Function(sub_func.params, sub_func.body) + global_func = tvm.relay.GlobalVar(func_name) + module.mod[global_func] = func_expr + module.mod = InferType()(module.mod) + prelude.mod = module.mod + + param_exprs = [] + for param_expr in sub_func.params: + # sub_params is subset of sub_func.params + param_name = param_expr.vid.name_hint + if param_name in input_expr_dict.keys(): + param_exprs.append(input_expr_dict[param_name]) + elif param_name in sub_params.keys(): + param_exprs.append(param_expr) + else: + raise Exception("Input parameter {} not found".format(param_name)) + + sb = tvm.relay.scope_builder.ScopeBuilder() + loop_ret = global_func(*param_exprs) + sb.ret(loop_ret) + ret = sb.get() + return ret + + +def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): + """convert tensorflow2.x graph into relay function. + + Parameters + ---------- + graph_def : must be frozen graph (no variables allowed). + Placeholders are assumed to be inputs to the graph. + + tensorflow/core/framework/graph.proto + message GraphDef { + repeated NodeDef node = 1; + FunctionDefLibrary library = 2; + } + tensorflow/core/framework/function.proto + message FunctionDef { + repeated NodeDef node_def = 3; + } + + layout : str + The layout for the model. + + shape : List[str, List[int]] + Input to the model. It is a key and shape vector mapping. Applies to placeholders. + + outputs : List[str] + The list of output nodes. The last node is treated as the output if not + specified. + + Returns + ------- + mod : tvm.IRModule + The module that optimizations will be performed on. + + params : dict of str to tvm.nd.NDArray + Dict of converted parameters stored in tvm.nd.NDArray format. + + Examples + -------- + "x+1" tf module where x has a shape of (2,2) is converted as follows: + + mod : tvm.IRModule + def @func___inference_add_95(%x: Tensor[(2, 2), float32], %add/y: Tensor[(2, 2), float32]) + -> Tensor[(2, 2), float32] { + add(%x, %add/y) /* Identity */ /* ty=Tensor[(2, 2), float32] */ + } + + def @main(%x1: Tensor[(2, 2), float32], %add/y1: Tensor[(2, 2), float32]) { + @func___inference_add_95(%x1, %add/y1) /* Identity */ + } + + params : dict of str to tvm.nd.NDArray + {'add/y': + + """ + + # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init + graph_def_library = {} + for func in graph_def.library.function: + inshape = func.attr["_input_shapes"].list.shape + graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def( + func, inshape + ) + module = RelayModule() + g = GraphProto(module) + func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library) + module.mod["main"] = func + module.params.update(params) + return module.mod, module.params diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py new file mode 100644 index 000000000000..c7385565857d --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -0,0 +1,2998 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=import-outside-toplevel, redefined-builtin +"""TF: Tensorflow frontend.""" +import warnings +from collections import deque + +# Numpy support +import numpy as np +import tvm + +from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape +from tvm.topi.utils import get_const_tuple + +from .. import expr as _expr +from .. import op as _op +from ..ty import Any +from .common import AttrCvt, get_relay_op +from .common import infer_type as _infer_type +from .common import infer_shape as _infer_shape +from .common import infer_channels as _infer_channels +from .common import infer_value as _infer_value + + +def check_symbolic_shape(shape): + return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape]) + + +def list_shape_of(tensor, ndim): + shape_tensor = _op.shape_of(tensor) + return [ + _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim) + ] + + +def _get_pad_pair(input1d, kernel1d, stride1d): + if isinstance(input1d, tvm.tir.Any) and stride1d != 1: + raise tvm.error.OpAttributeUnImplemented( + "SAME padding is not supported in combination with dynamic height or width when stride" + " is not 1." + ) + if stride1d == 1 or input1d % stride1d == 0: + pad = max(kernel1d - stride1d, 0) + else: + pad = max(kernel1d - (input1d % stride1d), 0) + + pad_before = pad // 2 + pad_after = pad - pad_before + + return [pad_before, pad_after] + + +def _math_name_picker(surfix): + def _impl(attr): + return "broadcast_" + surfix + + return _impl + + +def _dimension_picker(prefix, surfix=""): + def _impl(attr): + kernel = attr["kernel_shape"] + if len(kernel) == 2: + return prefix + "2d" + surfix + if len(kernel) == 3: + return prefix + "3d" + surfix + raise tvm.error.OpAttributeInvalid( + "Only 2D or 3D kernels are supported for operator {}".format(prefix + "2d or 3d") + ) + + return _impl + + +def _dimension_constraint(): + def _dim_check(attrs): + if len(attrs["kernel_shape"]) in (2, 3): + return True + return False + + return _dim_check, "Only 2d or 3d kernel supported." + + +def _get_param(params, input_node): + if isinstance(input_node, _expr.Constant): + return np.atleast_1d(input_node.data.numpy()) + return params[input_node.name_hint].numpy() + + +def _get_num_param(params, input_node): + return _get_param(params, input_node).item() + + +def _get_list_param(params, input_node, mod): + try: + return _get_param(params, input_node).tolist() + except (IndexError, KeyError, AttributeError): + return _infer_value(input_node, params, mod).numpy().tolist() + + +def _get_tuple_param(params, input_node): + return tuple(_get_param(params, input_node)) + + +def _need_prelude_for_shape_inference(op): + return "TensorArray" in op + + +def _get_more_static_shape(shape0, shape1): + """Compare two shapes with the same rank, + and return the one with fewer symbolic dimension. + """ + assert len(shape0) == len(shape1) + num_sym_dim0 = 0 + num_sym_dim1 = 0 + for dim0, dim1 in zip(list(shape0), list(shape1)): + if not isinstance(dim0, int): + num_sym_dim0 += 1 + if not isinstance(dim1, int): + num_sym_dim1 += 1 + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + + +def _rsqrt(): + def _impl(inputs, attr, params, mod): + inputs.append(tvm.relay.const(-0.5, attr["T"].name)) + return AttrCvt(op_name="power")(inputs, attr) + + return _impl + + +def _argx(func, func_name): + """A common wrapper for argmin and argmax operations""" + + def _impl(inputs, attr, params, mod): + try: + # In Tensorflow, `axis` argument is a Tensor, not attribute. We + # support the case where it inputs from a scalar constant. + axis_input_value = [_get_num_param(params, inputs[1])] + except (IndexError, KeyError): + raise TypeError( + "Unsupported argument for `{}` : `axis` should be a constant".format(func_name) + ) + out = func(inputs[0], axis=axis_input_value, keepdims=False) + dtype = attr["output_type"].name + if dtype != "int32": + out = _op.cast(out, dtype=dtype) + return out + + return _impl + + +def _elemwise(name): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) + return get_relay_op(name)(*inputs) + + return _impl + + +def _pool3d(name): + def _impl(inputs, attr, params, mod): + attr["data_format"] = attr["data_format"].decode("utf-8") + flip_layout = False + + input_shape = _infer_shape(inputs[0], mod) + + if attr["data_format"] == "NDHWC": + attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2], attr["ksize"][3]) + attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3]) + elif attr["data_format"] == "NCDHW": + attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3], attr["ksize"][4]) + attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4]) + else: + msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) + if attr["data_format"] == "NDHWC": + input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)] + inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) + attr["data_format"] = "NCDHW" + flip_layout = True + + attr["padding"] = attr["padding"].decode("utf-8") + + if attr["padding"] == "VALID": + attr["padding"] = [0, 0, 0, 0, 0, 0] + elif attr["padding"] == "SAME": + stride_d, stride_h, stride_w = attr["strides"] + kernel_d, kernel_h, kernel_w = attr["kernel_shape"] + if attr["data_format"] == "NDHWC": + in_d = input_shape[1] + in_h = input_shape[2] + in_w = input_shape[3] + else: + in_d = input_shape[2] + in_h = input_shape[3] + in_w = input_shape[4] + pad_d = _get_pad_pair(in_d, kernel_d, stride_d) + pad_v = _get_pad_pair(in_h, kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + + attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] + else: + msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) + + if name == "avg_pool": + attr["count_include_pad"] = False + attr["ceil_mode"] = False + out = AttrCvt( + op_name=name, + transforms={"kernel_shape": "pool_size", "data_format": "layout"}, + ignores=["ksize"], + )(inputs, attr) + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) + return out + + return _impl + + +def _pooling(name): + def _impl(inputs, attr, params, mod): + + attr["data_format"] = attr["data_format"].decode("utf-8") + flip_layout = False + + input_shape = _infer_shape(inputs[0], mod) + + if attr["data_format"] == "NHWC": + attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2]) + attr["strides"] = (attr["strides"][1], attr["strides"][2]) + elif attr["data_format"] == "NCHW": + attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3]) + attr["strides"] = (attr["strides"][2], attr["strides"][3]) + else: + msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) + + if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": + tmp_shape = _infer_shape(inputs[0], mod) + input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) + attr["data_format"] = "NCHW" + flip_layout = True + + # Fix padding + attr["padding"] = attr["padding"].decode("utf-8") + + if attr["padding"] == "VALID": + attr["padding"] = [0, 0] + elif attr["padding"] == "SAME": + stride_h, stride_w = attr["strides"] + kernel_h, kernel_w = attr["kernel_shape"] + if attr["data_format"] == "NHWC": + in_h = input_shape[1] + in_w = input_shape[2] + else: + in_h = input_shape[2] + in_w = input_shape[3] + + pad_v = _get_pad_pair(in_h, kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + + attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 8 + if flip_layout or attr["data_format"] == "NHWC": + attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] + else: + attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] + else: + msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) + + if name == "avg_pool": + attr["count_include_pad"] = False + + out = AttrCvt( + op_name=_dimension_picker(name), + transforms={"kernel_shape": "pool_size", "data_format": "layout"}, + ignores=["ksize", "explicit_paddings"], + extras={"ceil_mode": False}, + custom_check=_dimension_constraint(), + )(inputs, attr) + + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 1)) + + return out + + return _impl + + +def _conv(opname): + def _impl(inputs, attr, params, mod): + attr["data_format"] = attr["data_format"].decode("utf-8") + flip_layout = False + + if opname == "conv_transpose" and attr["data_format"] == "NHWC": + # transform to NCHW for TVM backend compatible and set 'flip_layout' + # to have output flip back to NHWC + inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) + attr["strides"][1], attr["strides"][2], attr["strides"][3] = ( + attr["strides"][3], + attr["strides"][1], + attr["strides"][2], + ) + attr["data_format"] = "NCHW" + + # Check whether output shapes attribute is set and not None + if ( + opname == "conv_transpose" + and len(attr["_output_shapes"]) > 0 + and attr["_output_shapes"][0] + ): + tmp_shape = attr["_output_shapes"][0] + tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + attr["_output_shapes"][0] = tmp_shape + + flip_layout = True + + inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] + + # NCHW Layout require weights transpose + weights_shape = _infer_shape(inputs[1], mod) + if attr["data_format"] == "NCHW": + tmp_shape = weights_shape + if opname in ["conv", "conv_transpose"]: + tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) + else: + tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) + weights_shape = tmp_shape + + input_shape = _infer_shape(inputs_data, mod) + if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": + input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) + if opname in ["conv", "conv_transpose"]: + weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) + else: + weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) + + attr["data_format"] = "NCHW" + attr["strides"] = [attr["strides"][ii] for ii in (0, 3, 1, 2)] + flip_layout = True + + if attr["data_format"] == "NHWC": + in_channels = input_shape[3] + kernel_h, kernel_w, _, depth_mult = weights_shape + attr["kernel_shape"] = (weights_shape[0], weights_shape[1]) + if opname == "conv": + attr["channels"] = weights_shape[3] + elif opname == "conv_transpose": + attr["channels"] = weights_shape[2] + else: + attr["channels"] = input_shape[3] * depth_mult + + if "dilations" in attr: + attr["dilations"] = (attr["dilations"][1], attr["dilations"][2]) + attr["strides"] = (attr["strides"][1], attr["strides"][2]) + elif attr["data_format"] == "NCHW": + in_channels = input_shape[1] + _, depth_mult, kernel_h, kernel_w = weights_shape + attr["kernel_shape"] = (weights_shape[2], weights_shape[3]) + if opname == "conv": + attr["channels"] = weights_shape[0] + elif opname == "conv_transpose": + attr["channels"] = weights_shape[1] + else: + attr["channels"] = input_shape[1] * depth_mult + if attr["channels"] < 0: + attr["channels"] *= -1 + + if "dilations" in attr: + attr["dilations"] = (attr["dilations"][2], attr["dilations"][3]) + attr["strides"] = (attr["strides"][2], attr["strides"][3]) + else: + msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) + + if opname == "depthwise": + attr["groups"] = in_channels + + # Fix padding + attr["padding"] = attr["padding"].decode("utf-8") + + if attr["padding"] == "VALID": + attr["padding"] = [0, 0] + elif attr["padding"] == "SAME": + stride_h, stride_w = attr["strides"] + kernel_h, kernel_w = attr["kernel_shape"] + + pdata_shape = input_shape + # Check whether output shapes attribute is set and not None + if ( + opname == "conv_transpose" + and len(attr["_output_shapes"]) > 0 + and attr["_output_shapes"][0] + ): + pdata_shape = attr["_output_shapes"][0] + + if attr["data_format"] == "NHWC": + in_h = pdata_shape[1] + in_w = pdata_shape[2] + else: + in_h = pdata_shape[2] + in_w = pdata_shape[3] + + dilation_h = attr["dilations"][0] + dilation_w = attr["dilations"][1] + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + + attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 8 + if flip_layout or attr["data_format"] == "NHWC": + attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] + else: + attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) + + if "kernel_layout" not in attr: + if opname in ["conv", "conv_transpose"]: + attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW" + else: + attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW" + + # Ignore the new attributes from TF2.0, for now. + out = AttrCvt( + op_name=_dimension_picker( + "conv", surfix="_transpose" if opname == "conv_transpose" else "" + ), + ignores=["explicit_paddings"], + transforms={ + "kernel_shape": "kernel_size", + "data_format": "data_layout", + "dilations": ("dilation", (0, 0)), + "group": ("groups", 1), + }, + custom_check=_dimension_constraint(), + )([inputs_data, inputs[1]], attr) + + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 1)) + + return out + + return _impl + + +# Dilation2d +def _dilation2d(): + def _impl(inputs, attr, params, mod): + if "data_format" not in attr: + attr["data_format"] = "NHWC" + + input_shape = _infer_shape(inputs[0], mod) + weights_shape = _infer_shape(inputs[1], mod) + + if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": + input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) + weights_shape = [weights_shape[ii] for ii in (2, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1)) + attr["data_format"] = "NCHW" + + if attr["data_format"] in ["NHWC", "NCHW"]: + if "rates" in attr: + attr["dilations"] = attr["rates"] + if "dilations" in attr: + attr["dilations"] = (attr["dilations"][1], attr["dilations"][2]) + attr["strides"] = (attr["strides"][1], attr["strides"][2]) + else: + msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' "not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) + + attr["padding"] = attr["padding"].decode("utf-8") + if attr["padding"] == "VALID": + attr["padding"] = [0, 0] + elif attr["padding"] == "SAME": + stride_h, stride_w = attr["strides"] + if attr["data_format"] == "NHWC": + kernel_h, kernel_w = weights_shape[0], weights_shape[1] + else: + kernel_h, kernel_w = weights_shape[1], weights_shape[2] + if attr["data_format"] == "NHWC": + in_h = input_shape[1] + in_w = input_shape[2] + else: + in_h = input_shape[2] + in_w = input_shape[3] + + dilation_h = attr["dilations"][0] + dilation_w = attr["dilations"][1] + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + + if attr["data_format"] == "NHWC": + inputs[0] = _op.nn.pad( + data=inputs[0], + pad_width=((0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1]), (0, 0)), + ) + else: + inputs[0] = _op.nn.pad( + data=inputs[0], + pad_width=((0, 0), (0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1])), + ) + + attr["padding"] = [0, 0] + + else: + msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' "valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) + + attr["kernel_layout"] = "HWI" if attr["data_format"] == "NHWC" else "IHW" + out = AttrCvt( + op_name="dilation2d", + ignores=["explicit_paddings", "rates"], + transforms={ + "data_format": "data_layout", + }, + )([inputs[0], inputs[1]], attr) + if attr["_target_layout"] == "NCHW": + out = _op.transpose(out, axes=(0, 2, 3, 1)) + return out + + return _impl + + +def _conv3d(opname): + def _impl(inputs, attr, params, mod): + attr["data_format"] = attr["data_format"].decode("utf-8") + flip_layout = False + + inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] + + # NCDHW Layout require weights transpose + weights_shape = _infer_shape(inputs[1], mod) + if attr["data_format"] == "NCDHW": + tmp_shape = weights_shape + tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] + inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) + weights_shape = tmp_shape + + input_shape = _infer_shape(inputs_data, mod) + + if attr["_target_layout"] == "NCDHW" and attr["data_format"] == "NDHWC": + input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] + inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3)) + weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)] + inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) + + attr["data_format"] = "NCDHW" + attr["strides"] = [attr["strides"][ii] for ii in (0, 4, 1, 2, 3)] + flip_layout = True + + if attr["data_format"] == "NDHWC": + kernel_d, kernel_h, kernel_w, _, _ = weights_shape + attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w) + if opname == "conv": + attr["channels"] = weights_shape[4] + elif opname == "conv_transpose": + attr["channels"] = weights_shape[3] + + if "dilations" in attr: + attr["dilations"] = ( + attr["dilations"][1], + attr["dilations"][2], + attr["dilations"][3], + ) + attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3]) + elif attr["data_format"] == "NCDHW": + _, _, kernel_d, kernel_h, kernel_w = weights_shape + attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w) + if opname == "conv": + attr["channels"] = weights_shape[0] + elif opname == "conv_transpose": + attr["channels"] = weights_shape[1] + + if "dilations" in attr: + attr["dilations"] = ( + attr["dilations"][2], + attr["dilations"][3], + attr["dilations"][4], + ) + attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4]) + else: + msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) + + # Fix padding + attr["padding"] = attr["padding"].decode("utf-8") + + if attr["padding"] == "VALID": + attr["padding"] = [0, 0, 0] + elif attr["padding"] == "SAME": + stride_d, stride_h, stride_w = attr["strides"] + kernel_d, kernel_h, kernel_w = attr["kernel_shape"] + + pdata_shape = input_shape + if opname == "conv_transpose" and len(attr["_output_shapes"]) > 0: + pdata_shape = attr["_output_shapes"][0] + + if attr["data_format"] == "NDHWC": + in_d = pdata_shape[1] + in_h = pdata_shape[2] + in_w = pdata_shape[3] + else: + in_d = pdata_shape[2] + in_h = pdata_shape[3] + in_w = pdata_shape[4] + + dilation_d = attr["dilations"][0] + dilation_h = attr["dilations"][1] + dilation_w = attr["dilations"][2] + dilated_kernel_d = (kernel_d - 1) * dilation_d + 1 + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d) + pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + + attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 10 + if flip_layout or attr["data_format"] == "NDHWC": + attr["padding"] = [ + paddings[2], + paddings[4], + paddings[6], + paddings[3], + paddings[5], + paddings[7], + ] + else: + attr["padding"] = [ + paddings[4], + paddings[6], + paddings[8], + paddings[5], + paddings[7], + paddings[9], + ] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." + raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) + + if "kernel_layout" not in attr: + attr["kernel_layout"] = "DHWIO" if attr["data_format"] == "NDHWC" else "OIDHW" + + use_bias = len(inputs) == (3 if opname != "conv_transpose" else 4) + channel_axis = 1 if attr["data_format"] == "NCDHW" else 4 + + # Ignore the new attributes from TF2.0, for now. + out = AttrCvt( + op_name=_dimension_picker( + "conv", surfix="_transpose" if opname == "conv_transpose" else "" + ), + ignores=["explicit_paddings", "Tshape"], + transforms={ + "kernel_shape": "kernel_size", + "data_format": "data_layout", + "dilations": ("dilation", (0, 0)), + "group": ("groups", 1), + }, + custom_check=_dimension_constraint(), + )([inputs_data, inputs[1]], attr) + + if use_bias: + out = _op.nn.bias_add( + out, inputs[2] if opname != "conv_transpose" else inputs[3], axis=channel_axis + ) + + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) + + return out + + return _impl + + +def _nms(return_scores=False): + def _impl(inputs, attr, params, mod): + # Get parameter values + try: + max_output_size = int(np.atleast_1d(inputs[2].data.numpy().astype("int64"))[0]) + except Exception: + try: + max_output_size = ( + _infer_value(inputs[2], params, mod).numpy().astype("int64").tolist()[0] + ) + except Exception: + max_output_size = inputs[2] + iou_threshold = np.atleast_1d(inputs[3].data.numpy())[0] + # score_threshold was introduced from V3 + score_threshold = np.atleast_1d(inputs[4].data.numpy())[0] if len(inputs) > 4 else 0.0 + pad_output = "pad_to_max_output_size" + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt( + op_name="expand_dims", + ignores=["T_threshold", pad_output], + extras={"axis": -1, "num_newaxis": 1}, + )([inputs[1]], attr) + data = get_relay_op("concatenate")([scores, inputs[0]], -1) + data = get_relay_op("expand_dims")(data, 0, 1) + + # reason why using get_valid_counts is for inference performance + ct, data, indices = get_relay_op("get_valid_counts")( + data, score_threshold=score_threshold, id_index=-1, score_index=0 + ) + # TensorFlow NMS doesn't have parameter top_k + top_k = -1 + # TF doesn't have class id for nms input + score_index = 0 + nms_ret = get_relay_op("non_max_suppression")( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + ) + + if pad_output in attr and attr[pad_output]: + return nms_ret + # squeeze it, TF NMS is not batched + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # slice to get the dynamic result + ret = get_relay_op("strided_slice")( + data_slice, begin=_expr.const([0]), end=size, slice_mode="size" + ) + + # NonMaxSuppressionV5 returns scores. pad_output is always False for NMSv5. + if return_scores: + if "soft_nms_sigma" in attr and attr["soft_nms_sigma"] != 0.0: + raise tvm.error.OpAttributeUnImplemented( + "soft_nms_sigma for NonMaxSuppressionV5 is not supported" + ) + ret_scores = _op.take(inputs[1], ret, axis=0) + return _expr.TupleWrapper(_expr.Tuple([ret, ret_scores, size]), 3) + + return ret + + return _impl + + +def convert_combined_nms_with_all_class_nms( + batch_size, + max_output_boxes_per_batch, + num_class, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + clip_boxes, +): + """Converts TF combined_nms using Relay all_class_max_suppression op""" + (selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="tensorflow", + ) + box_range = _op.arange( + _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" + ) + assert isinstance(batch_size, int), "dynamic batch size not supported yet." + tile_batch_reps = _op.const([batch_size, 1]) + box_range_2d = _op.tile(box_range, tile_batch_reps) + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) + + def select_topk(do_zero_pad): + def true_branch(): + arange = _op.arange( + _op.const(0, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), + dtype="int64", + ) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) + ) + topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) + nmsed_scores = _op.gather(selected_scores, 1, topk_indices) + nmsed_scores = nmsed_scores * valid_mask + return nmsed_scores, topk_indices + + def false_branch(): + if isinstance(max_output_boxes_per_class, int): + # Do topk on smaller input if possible + slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") + selected_scores_slice = _op.strided_slice( + selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] + ) + else: + selected_scores_slice = selected_scores + return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both") + + # TODO(masahi): support dynamic num_boxes + # return _expr.If(do_zero_pad, true_branch(), false_branch()) + return true_branch() if do_zero_pad else false_branch() + + assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet." + nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) + + indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) + nmsed_classes = _op.take(indices, _op.const(0), axis=2) + nmsed_classes = _op.cast(nmsed_classes, "float32") + nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) + num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64")) + + if clip_boxes: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 + ) + + +def _combined_nms(): + def _impl(inputs, attr, params, mod): + # Get parameter values + boxes = inputs[0] + scores = inputs[1] + try: + max_output_size = int(np.atleast_1d(inputs[2].data.numpy().astype("int64"))[0]) + except Exception: + try: + max_output_size = ( + _infer_value(inputs[2], params, mod).numpy().astype("int64").tolist()[0] + ) + except Exception: + max_output_size = inputs[2] + max_total_size = inputs[3] + iou_threshold = np.atleast_1d(inputs[4].data.numpy())[0] + score_threshold = np.atleast_1d(inputs[5].data.numpy())[0] + if attr["pad_per_class"]: + raise tvm.error.OpAttributeUnImplemented( + "pad_per_class for CombinedNonMaxSuppression is not supported" + ) + boxes_shape = _infer_shape(inputs[0], mod) + scores_shape = _infer_shape(inputs[1], mod) + batch_size = boxes_shape[0] + num_anchors = boxes_shape[1] + q = boxes_shape[2] + num_classes = scores_shape[2] + + assert isinstance(batch_size, int) and isinstance( + num_anchors, int + ), "Dynamic inputs not supported yet" + + if q == 1: + boxes = _op.squeeze(boxes, axis=[2]) + scores_trans = _op.transpose(scores, [0, 2, 1]) + max_output_boxes_per_batch = num_anchors * num_classes + return convert_combined_nms_with_all_class_nms( + batch_size, + max_output_boxes_per_batch, + num_classes, + boxes, + scores_trans, + max_output_size, + iou_threshold, + score_threshold, + max_total_size.data.numpy().item(), + attr["clip_boxes"], + ) + + boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4]) + scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1]) + + # In TF, class is specified by memory layout only. + ids = _op.arange(_op.const(num_classes, dtype="float32")) + ids = _op.broadcast_to(ids, (batch_size, num_anchors, num_classes)) + ids = _op.reshape(ids, newshape=[batch_size, num_anchors * num_classes, 1]) + + data = _op.concatenate([ids, scores, boxes], -1) + ct, data, indices = _op.vision.get_valid_counts( + data, score_threshold=score_threshold, id_index=0, score_index=1 + ) + nms_ret = _op.vision.non_max_suppression( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=True, + ) + # Dynamic slice to max_total_size + neg_one = _expr.const([-1]) + slice_end = _op.concatenate( + [neg_one, _op.expand_dims(max_total_size, axis=0), neg_one], axis=0 + ) + nms_ret = _op.strided_slice( + nms_ret, begin=[0, 0, 0], end=slice_end, strides=[1, 1, 1], slice_mode="size" + ) + + # Slice output into boxes, scores, classes + nmsed_boxes = _op.strided_slice( + nms_ret, begin=[0, 0, 2], end=[-1, -1, 4], slice_mode="size" + ) + if attr["clip_boxes"]: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + nmsed_scores = _op.strided_slice( + nms_ret, begin=[0, 0, 1], end=[-1, -1, 1], slice_mode="size" + ) + nmsed_scores = _op.squeeze(nmsed_scores, axis=[2]) + nmsed_classes = _op.strided_slice( + nms_ret, begin=[0, 0, 0], end=[-1, -1, 1], slice_mode="size" + ) + nmsed_classes = _op.squeeze(nmsed_classes, axis=[2]) + # Get number of valid boxes + nms_count = _op.sum( + _op.cast(_op.greater(nmsed_scores, _expr.const(0, dtype="float32")), "int32"), axis=1 + ) + + # TVM uses -1 for invalid outputs while TF uses 0 + box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32") + shape = _op.strided_slice(_op.shape_of(nmsed_boxes), begin=[0], end=[2]) + box_range = _op.broadcast_to(box_range, shape) + valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(nms_count, axis=1)), "float32") + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + # Could instead use mask for scores, classes if negative values are possible. + nmsed_scores = _op.maximum(nmsed_scores, _expr.const(0, dtype="float32")) + nmsed_classes = _op.maximum(nmsed_classes, _expr.const(0, dtype="float32")) + + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, nms_count]), 4 + ) + + return _impl + + +def _decode_image(): + def _impl(inputs, attr, params, mod): + # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. + warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input") + return inputs[0] + + return _impl + + +def _unravel_index(): + def _impl(inputs, attr, params, mod): + return _op.unravel_index(inputs[0], inputs[1]) + + return _impl + + +def _crop_and_resize(): + def _impl(inputs, attr, params, mod): + # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] + # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] + crop_size = _get_list_param(params, inputs[3], mod) + + method = attr["method"].decode() + method = "nearest_neighbor" if method == "nearest" else method + if method not in ["bilinear", "nearest_neighbor"]: + raise tvm.error.OpAttributeUnImplemented("Method {} is not supported".format(method)) + layout = attr["layout"] if "layout" in attr else "NHWC" + extrapolation_value = attr["extrapolation_value"] + + return get_relay_op("crop_and_resize")( + inputs[0], inputs[1], inputs[2], crop_size, layout, method, extrapolation_value + ) + + return _impl + + +def _cast(): + def _impl(inputs, attr, params, mod): + return inputs[0].astype(attr["DstT"].name) + + return _impl + + +def _expand_dims(): + def _impl(inputs, attr, params, mod): + dim_input = inputs.pop(1) + axis = _get_num_param(params, dim_input) + return AttrCvt( + op_name="expand_dims", + ignores=["Tdim", "N"], + extras={"axis": int(axis), "num_newaxis": 1}, + )(inputs, attr) + + return _impl + + +def _expm1(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/expm1 + def _impl(inputs, attr, params, mod): + exp_out = get_relay_op("exp")(inputs[0]) + return exp_out - tvm.relay.const(1.0) + + return _impl + + +def _resize(method): + def _impl(inputs, attr, params, mod): + if attr["_output_shapes"][0] is not None: + size = attr["_output_shapes"][0][1:3] + # Important that the size is defined. If an axis is not, we need to infer what + # the shape should be. + if -1 in size: + size = _infer_value(inputs[1], params, mod).numpy().reshape([-1]).tolist() + else: + size = _infer_value(inputs[1], params, mod).numpy().reshape([-1]).tolist() + + attr["size"] = size + inputs.pop(1) + # NHWC + attr["layout"] = "NHWC" + if attr.pop("align_corners") is True: + attr["coordinate_transformation_mode"] = "align_corners" + else: + attr["coordinate_transformation_mode"] = "asymmetric" + + # Ignore the new attributes from TF2.0, for now. + return AttrCvt( + op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} + )(inputs, attr) + + return _impl + + +def _check_numerics(): + def _impl(inputs, attr, params, mod): + # Making a copy node assuming no need to verify + return AttrCvt(op_name="copy", ignores=["message"])(inputs, attr) + + return _impl + + +def _assert(): + # ToDo: In general people want asserts to be gone from TensorFlow graphs + # when they are optimizing them, so converting it to a no-op is + # reasonable. However, it would be nice to have the option to keep them + # once Relay gets a Halt or Assert op. + return _no_op() + + +def _no_op(): + def _impl(inputs, attr, params, mod): + # ToDo: This should really be an op that returns nothing, which could + # be represented as an empty tuple. It turns out that TVM + # infrastructure doesn't like running functions that return None and + # also don't like running functions that return an empty tuple. So it + # doesn't work, but it should be made to work and then this could be + # improved. In the mean time, it is hard to imagine a case where it + # matters in any real way that a no-op is converted to a constant 0. + return tvm.relay.const(0) + + return _impl + + +def _matmul(): + def _impl(inputs, attr, params, mod): + channels = _infer_channels(inputs[1], not attr["transpose_b"]) + if attr["transpose_a"]: + inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) + if not attr["transpose_b"]: + inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) + return AttrCvt( + op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"] + )(inputs, attr) + + return _impl + + +def _batch_matmul(): + def _impl(inputs, attr, params, mod): + input_x = inputs[0] + input_y = inputs[1] + orig_shape_x = _infer_shape(input_x, mod) + orig_shape_y = _infer_shape(input_y, mod) + ndim = len(orig_shape_x) + + is_static = not check_symbolic_shape(orig_shape_x) + + if ndim > 3 and not is_static: + shape_of_x = list_shape_of(inputs[0], ndim) + shape_of_y = list_shape_of(inputs[1], ndim) + + # reshape n-dimensional batch matmul into 3d + if ndim > 3: + outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] + if is_static: + num_outer_elts = np.prod(outer_dims) + new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + else: # handle dynamic shape (dyn.reshape op) + # new shape = [prod(shape[:-2]), -2, -1] + new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]] + new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]] + for i in range(ndim - 2): + new_shape_x[0] *= shape_of_x[i] + new_shape_y[0] *= shape_of_y[i] + new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0) + new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0) + + input_x = _op.reshape(input_x, newshape=new_shape_x) + input_y = _op.reshape(input_y, newshape=new_shape_y) + + adj_x = attr["adj_x"] + adj_y = attr["adj_y"] + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y + ret = get_relay_op("batch_matmul")(input_x, input_y) + + # reshape result back to n-dimensional + if ndim > 3: + if is_static: + final_shape = list(orig_shape_x) + final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2] + final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] + else: + # calculate the resulting shape = [shape[:-2], 0, 0] + final_shape = list(shape_of_x) + final_shape[-2] = shape_of_x[-1] if adj_x else shape_of_x[-2] + final_shape[-1] = shape_of_y[-2] if adj_y else shape_of_y[-1] + final_shape = _op.concatenate(_op.Tuple(final_shape), axis=0) + + ret = _op.reshape(ret, newshape=final_shape) + return ret + + return _impl + + +def _sparse_tensor_dense_matmul(): + def _impl(inputs, attr, params, mod): + # Loading this by default causes TVM to not be loadable from other languages. + # Sparse utility from scipy + from scipy.sparse import csr_matrix + + assert len(inputs) == 4, "There should be 4 input tensors" + + indices_tensor = _infer_value(inputs[0], params, mod).numpy() + values_tensor = _infer_value(inputs[1], params, mod).numpy() + dense_shape_tensor = _infer_value(inputs[2], params, mod).numpy() + + data = inputs[3] + + rows = [x[0] for x in indices_tensor] + cols = [x[1] for x in indices_tensor] + + # Create scipy sparse Tensor(CSR) + weight_sp = csr_matrix( + (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) + ) + + # As per tensorflow implementation, we have 4 possible input combination + # and the first input(A) is always sparse and second input(B) is always dense. + # Case 1: A , B , adjoint_a=False, adjoint_b=False --> A * B + # Case 2: A , B , adjoint_a=True, adjoint_b=False --> A.T * B + # Case 3: A , B , adjoint_a=False, adjoint_b=True --> A * B.T + # Case 4: A , B , adjoint_a=True, adjoint_b=True --> A.T * B.T + # + # Topi implementation for sparse_dense(matmul) has 2 possible input + # combination where first input(A) is always dense + # and second input(B) is always sparse. + # Case 1: A , B, sparse_lhs = False --> A * B.T + # Case 2: A , B, sparse_lhs = True --> B * A.T + # + # The mapping would be as below: + # TF Case 1: A , B , adjoint_a=False, adjoint_b=False + # --> In TF: A * B --> In Topi: A * B.T.T + # --> sparse_dense(transpose(B), A, sparse_lhs=True) + # + # TF Case 2: A , B , adjoint_a=True, adjoint_b=False + # --> In TF: A.T * B --> In Topi: A.T * B.T.T + # --> sparse_dense(transpose(B), transpose(A), sparse_lhs=True) + # + # TF Case 3: A , B , adjoint_a=False, adjoint_b=True + # --> In TF: A * B.T --> In Topi: A * B + # --> sparse_dense(B, A, sparse_lhs=True) + # + # TF Case 4: A , B , adjoint_a=True, adjoint_b=True + # --> In TF: A.T * B.T --> In Topi: (B * A.T).T + # --> transpose(sparse_dense(B, transpose(A), sparse_lhs=False)) + + # By default, in tensorflow the first input ,i.e., data is sparse + sparse_lhs = True + + # TF Case 1: + if not attr.get("adjoint_a") and not attr.get("adjoint_b"): + data = _op.transpose(data) + # TF Case 2: + elif attr.get("adjoint_a") and not attr.get("adjoint_b"): + data = _op.transpose(data) + weight_sp = csr_matrix(weight_sp.transpose()) + # TF Case 3: + elif not attr.get("adjoint_a") and attr.get("adjoint_b"): + pass + # TF Case 4: + # attr.get("adjoint_a") and attr.get("adjoint_b"): + else: + sparse_lhs = False + weight_sp = csr_matrix(weight_sp.transpose()) + + weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) + weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) + weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) + + ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs) + + if not sparse_lhs: + # TF Case 4 + ret = _op.transpose(ret) + + return ret + + return _impl + + +def _sparse_fill_empty_rows(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + sparse_indices = inputs[0] + sparse_values = inputs[1] + sparse_indices_num_cols = _infer_shape(sparse_indices, mod)[1] + first_column = _op.split(sparse_indices, sparse_indices_num_cols, axis=1)[0] + sorted_indices = _op.argsort(_op.squeeze(first_column)) + sorted_sparse_indices = _op.take(sparse_indices, sorted_indices, axis=0) + sorted_sparse_values = _op.take(sparse_values, sorted_indices, axis=0) + new_sparse_indices, new_sparse_values, empty_row_indicator = _op.sparse_fill_empty_rows( + sorted_sparse_indices, sorted_sparse_values, inputs[2], inputs[3] + ) + + return _expr.TupleWrapper( + _expr.Tuple([new_sparse_indices, new_sparse_values, empty_row_indicator]), + 3, + ) + + return _impl + + +def _sparse_reshape(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2]) + return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2) + + return _impl + + +def _math_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2, "There should be 2 input tensors" + return get_relay_op("segment_sum")(inputs[0], inputs[1]) + + return _impl + + +def _sparse_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + return _op.segment_sum(data, inputs[2]) + + return _impl + + +def _sparse_segment_sum_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.numpy().item()) + return _op.segment_sum(data, inputs[2], num_segments) + + return _impl + + +def row_wise_divide(multi_dim_tensor, one_dim_vector): + """ + This function enables row-wise division of multi_dim_tensor and one_dim_vector. + To achieve this, it is first tiled to the appropriate shape and then elemwise_division + """ + multi_dim_tensor_offrow_shape = _op.strided_slice( + _op.shape_of(multi_dim_tensor, "int32"), [1], [-1], slice_mode="size" + ) + one_dim_vector_tiled_shape = _op.concatenate( + [_op.reverse(multi_dim_tensor_offrow_shape, 0), _expr.const([1])], axis=0 + ) + one_dim_vector_tiled = _op.transpose(_op.tile(one_dim_vector, one_dim_vector_tiled_shape)) + return _op.divide(multi_dim_tensor, one_dim_vector_tiled) + + +def count_all_indices(segment_ids, counts_dtype, num_segments=None): + """ + This snippet calculates the sqrt count of each index among all valid indices + Valid indices are from 0 to max of [segment ids, num_segments] + """ + + max_segments = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1]) + if num_segments: + max_segments = _op.maximum(max_segments, _expr.const([num_segments])) + max_ones = _op.maximum(max_segments, _op.shape_of(segment_ids)) + counts = _op.segment_sum( + _op.ones(max_ones, counts_dtype), segment_ids, num_segments=num_segments + ) + real_counts = _op.clip(counts, 1, 2147483647) # Clip max doesn't work over int32 + return real_counts + + +def _sparse_segment_sum_sqrtn(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_sum_sqrtn_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.numpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_mean(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + +def _sparse_segment_mean_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.numpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + +def _sparse_tensor_dense_add(): + # Sparse utility from scipy + from scipy.sparse import csr_matrix + + def _impl(inputs, attr, params, mod): + assert ( + len(inputs) == 4 + ), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]." + + indices_tensor = _infer_value(inputs[0], params, mod).numpy() + values_tensor = _infer_value(inputs[1], params, mod).numpy() + dense_shape_tensor = _infer_value(inputs[2], params, mod).numpy() + + data = inputs[3] + + rows = [x[0] for x in indices_tensor] + cols = [x[1] for x in indices_tensor] + + # Create scipy sparse Tensor(CSR) + weight_sp = csr_matrix( + (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) + ) + + weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) + weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) + weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) + + ret = _op.nn.sparse_add(data, [weight_data, weight_indices, weight_indptrs]) + + return ret + + return _impl + + +def _identity(): + def _impl(inputs, attr, params, mod): + return inputs[0] + + return _impl + + +def _identityn(): + def _impl(inputs, attr, params, mod): + return inputs + + return _impl + + +def _concatV2(): + def _impl(inputs, attr, params, mod): + pop_node = inputs.pop(len(inputs) - 1) + axis = int(_get_num_param(params, pop_node)) + return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})( + [inputs], attr + ) + + return _impl + + +def _concat(): + def _impl(inputs, attr, params, mod): + pop_node = inputs.pop(0) + axis = int(_get_num_param(params, pop_node)) + return AttrCvt(op_name="concatenate", ignores=["N"], extras={"axis": axis})([inputs], attr) + + return _impl + + +def _pack(): + def _impl(inputs, attr, params, mod): + axis = int(attr["axis"]) + inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] + return _op.concatenate(inputs_reshaped, axis) + + return _impl + + +def _tensor_array(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("dtype").name + assert not attr["dynamic_size"], "Dynamic size tensor array is " "not supported in TVM yet." + + if "shape" in attr: + shape = attr["shape"] + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) + static_tensor_array_ops.register() + tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") + tensor_array = tensor_array_constructor(inputs[0]) + else: + tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) + tensor_array = tensor_array_constructor(inputs[0]) + return tensor_array + + return _impl + + +def _tensor_array_scatter(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("T").name + input_ta = inputs[0] + input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + values_shape = _infer_shape(inputs[2], prelude.mod) + input_t_shape = values_shape[1:] + indices_shape = _infer_shape(inputs[1], prelude.mod) + + if input_shape is None: + values_rank = len(values_shape) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_global_var(unstack_name, dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_global_var("tensor_array_scatter", dtype_str) + else: + input_t_shape = _get_more_static_shape(input_t_shape, input_shape) + values_shape = (values_shape[0],) + input_t_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_t_shape) + static_tensor_array_ops.register() + # Register static indices shape + if isinstance(indices_shape[0], int): + static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) + tensor_array_scatter_func = prelude.get_global_var_static( + "tensor_array_scatter", dtype_str, input_t_shape + ) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, values_shape) + static_tensor_array_ops.register() + unstack_function = prelude.get_global_var_static( + "tensor_array_unstack", dtype_str, values_shape + ) + values = unstack_function(inputs[2]) + ret = tensor_array_scatter_func(input_ta, inputs[1], values) + return ret + + return _impl + + +def _tensor_array_gather(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("dtype").name + input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) + indices_shape = _infer_shape(inputs[1], prelude.mod) + + if input_shape is None: + gather_func = prelude.get_var("tensor_array_gather", dtype_str) + out = gather_func(inputs[2], inputs[1]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + + if not isinstance(indices_shape[0], int): + gather_function = prelude.get_global_var_static( + "tensor_array_gather", dtype_str, input_shape + ) + out_tensor_t = gather_function(inputs[2], inputs[1]) + out_shape = (indices_shape[0],) + input_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) + static_tensor_array_ops.register() + + # Output shape is (indices_shape[0],) + input_shape + get_data_func = prelude.get_global_var_static( + "tensor_get_data", dtype_str, out_shape + ) + out = get_data_func(out_tensor_t) + else: + # For fixed length indices, directly generate static shape output + read_func = prelude.get_global_var_static( + "tensor_array_read", dtype_str, input_shape + ) + get_data_func = prelude.get_global_var_static( + "tensor_get_data", dtype_str, input_shape + ) + tensor_list = [] + for i in range(indices_shape[0]): + index = _op.take(inputs[1], tvm.relay.const(i)) + out_tensor = get_data_func(read_func(inputs[2], index)) + tensor_list.append(_op.expand_dims(out_tensor, axis=0)) + + if indices_shape[0] > 1: + out = _op.concatenate(tensor_list, axis=0) + else: + out = tensor_list[0] + + return out + + return _impl + + +def _tensor_array_size(): + def _impl(inputs, attr, params, prelude): + return prelude.length(inputs[0]) + + return _impl + + +def _tensor_array_write(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("T").name + input_ta = inputs[3] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_shape(inputs[2], prelude.mod) + input_rank = len(input_t_shape) + + if input_ta_shape is None: + tensor_name = "tensor{}".format(input_rank) + tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_func(inputs[2]) + write_func = prelude.get_global_var("tensor_array_write", dtype_str) + else: + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( + input_ta_rank, input_rank + ) + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_func(inputs[2]) + # Write tensor with more static shape + actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape) + if actual_shape != input_t_shape: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + + write_func = prelude.get_global_var_static( + "tensor_array_write", dtype_str, input_ta_shape + ) + + return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v) + + return _impl + + +def _tensor_array_read(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["dtype"].name + input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) + + if input_shape is None: + read_func = prelude.get_global_var("tensor_array_read", dtype_str) + out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + read_func = static_tensor_array_ops.get_global_var("tensor_array_read") + out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") + out = get_data_func(out_tensor) + + return out + + return _impl + + +def _tensor_array_split(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("T").name + input_ta = inputs[0] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + lengths = _op.cast(inputs[2], "int32") + lengths_shape = _infer_shape(lengths, prelude.mod) + value_shape = _infer_shape(inputs[1], prelude.mod) + input_rank = len(value_shape) + + if input_ta_shape is None: + tensor_name = "tensor{}".format(input_rank) + tensor_ctor = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_ctor(inputs[1]) + split_func = prelude.get_global_var("tensor_array_split", dtype_str) + else: + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( + input_ta_rank, input_rank + ) + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + + # Check static value/indices shape + if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): + static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, value_shape) + static_tensor_array_ops.register() + tensor_ctor = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_ctor(inputs[1]) + split_func = prelude.get_global_var_static( + "tensor_array_split", dtype_str, input_ta_shape + ) + + return split_func(input_ta, v, lengths) + + return _impl + + +def _tensor_array_concat(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["dtype"].name + input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude) + + if input_shape is None: + concat_func = prelude.get_global_var("tensor_array_concat", dtype_str) + out = concat_func(inputs[1]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + concat_func = prelude.get_global_var_static( + "tensor_array_concat", dtype_str, input_shape + ) + out_tensor = concat_func(inputs[1]) + out_shape = (Any(),) + input_shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) + static_tensor_array_ops.register() + get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) + out = get_data_func(out_tensor) + + return out + + return _impl + + +def _tile(): + def _impl(inputs, attr, params, mod): + reps_input = inputs.pop() + if isinstance(reps_input, _expr.Call): + np_reps = _infer_value(reps_input, params, mod).numpy() + reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])] + else: + reps = _get_list_param(params, reps_input, mod) + new_input = [inputs.pop(0)] + + return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])( + new_input, attr + ) + + return _impl + + +def _slice(): + def _impl(inputs, attr, params, mod): + try: + begin = _get_list_param(params, inputs[1], mod) + except Exception: + # Handle symbolic begin + begin = inputs[1] + try: + size = _get_list_param(params, inputs[2], mod) + except Exception: + # Handle symbolic size + size = inputs[2] + + # Align begin and strides for dynamic shape. + data_dim = len(_infer_shape(inputs[0], mod)) + strides = [1] * data_dim + if not isinstance(begin, (_expr.Call, _expr.Var)): + for _ in range(len(begin), data_dim): + begin.append(0) + elif not isinstance(size, (_expr.Call, _expr.Var)): + for _ in range(len(size), data_dim): + size.append(-1) + return _op.strided_slice( + inputs[0], begin=begin, end=size, strides=strides, slice_mode="size" + ) + + return _impl + + +def _reshape(): + def _impl(inputs, attr, params, mod): + pop_node = inputs.pop(1) + + try: + shape_arg = _get_tuple_param(params, pop_node) + except AttributeError: + # Shape operator is already pruned, hence + # try to infer shape by precompute prune if possible. + try: + params_new = _infer_value(pop_node, params, mod) + shape_arg = tuple(params_new.numpy().astype("int32").flatten()) + except Exception: + # Deal with symbolic shape case. + if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op): + # shape_of is the direct ancestor. + return _op.reshape_like(inputs[0], pop_node.args[0]) + shape_arg = pop_node + + return AttrCvt(op_name="reshape", extras={"newshape": shape_arg}, ignores=["Tshape"])( + inputs, attr + ) + + return _impl + + +def _depth_to_space(): + def _impl(inputs, attr, params, mod): + block_size = int(attr["block_size"]) + layout = attr["data_format"].decode("utf-8") + return _op.nn.depth_to_space(inputs[0], block_size, layout) + + return _impl + + +def _space_to_depth(): + def _impl(inputs, attr, params, mod): + block_size = int(attr["block_size"]) + layout = attr["data_format"].decode("utf-8") + return _op.nn.space_to_depth(inputs[0], block_size, layout) + + return _impl + + +def _sparse_to_dense(): + def _impl(inputs, attr, params, mod): + sparse_indices = inputs[0] + output_shape = inputs[1] + sparse_values = inputs[2] + default_value = inputs[3] + + return _op.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + + return _impl + + +def _bias_add(): + def _impl(inputs, attr, params, mod): + # Must expand for proper broadcasting in NCHW. + if "data_format" in attr and attr["data_format"].decode("utf-8") == "NCHW": + bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) + else: + bias = inputs[1] + return _op.add(inputs[0], bias) + + return _impl + + +def _broadcast_args(): + def _impl(inputs, attr, params, mod): + if isinstance(inputs[0], _expr.Var): + s0 = params[inputs[0].name_hint] + else: + s0 = _infer_value(inputs[0], params, mod) + if isinstance(inputs[1], _expr.Var): + s1 = params[inputs[1].name_hint] + else: + s1 = _infer_value(inputs[1], params, mod) + s0 = list(s0.numpy().reshape([-1])) + s1 = list(s1.numpy().reshape([-1])) + s0_size, s1_size = len(s0), len(s1) + + out = deque([]) + for i in range(1, min(s0_size, s1_size) + 1): + if s0[s0_size - i] == s1[s1_size - i]: + out.appendleft(s0[s0_size - i]) + elif s0[s0_size - i] == 1: + out.appendleft(s1[s1_size - i]) + else: + assert s1[s1_size - i] == 1, "Incompatible broadcast type %s and %s" % ( + s0[s0_size - i], + s1[s1_size - i], + ) + out.appendleft(s0[s0_size - i]) + if s0_size < s1_size: + for i in range(s0_size + 1, s1_size + 1): + out.appendleft(s1[s1_size - i]) + if s1_size < s0_size: + for i in range(s1_size + 1, s0_size + 1): + out.appendleft(s0[s0_size - i]) + return _expr.const(list(out), attr["T"].name) + + return _impl + + +def _broadcast_to(): + def _impl(inputs, attr, params, mod): + if isinstance(inputs[1], _expr.Var): + shape = params[inputs[1].name_hint] + else: + shape = _infer_value(inputs[1], params, mod) + shape = list(shape.numpy().reshape([-1])) + return _op.broadcast_to(inputs[0], shape) + + return _impl + + +def _squeeze(): + def _impl(inputs, attr, params, mod): + if len(attr["squeeze_dims"]) == 0: + attr["squeeze_dims"] = None + return AttrCvt( + op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T", "_cloned"] + )(inputs, attr) + + return _impl + + +def _fused_batch_norm(): + def _impl(inputs, attr, params, mod): + # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) + # Relay: (data, gamma, beta, moving_mean, moving_varience) + assert len(inputs) == 5 + axis = 3 + need_cast = False + + if "data_format" in attr: + attr["data_format"] = attr["data_format"].decode("utf-8") + if attr["data_format"] == "NCHW": + axis = 1 + if "U" in attr and attr["U"].name != attr["T"].name: + need_cast = True + inputs[0] = _op.cast(inputs[0], dtype=attr["U"].name) + # Check if mean and variance are empty + # If so, replace them with Mean and Variance Ops + # For run-time calculation + moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] + moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] + if moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0: + inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) + inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) + out = AttrCvt( + op_name="batch_norm", + transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"}, + extras={"axis": axis}, + ignores=["data_format", "U", "exponential_avg_factor"], + disables=["momentum"], + )(inputs, attr) + + if need_cast: + out = _expr.TupleGetItem(out.astuple(), 0) + out = _op.cast(out, dtype=attr["T"].name) + return out + + return _impl + + +def _batch_norm(): + def _impl(inputs, attr, params, mod): + # Rearrange inputs from + # (data, moving_mean, moving_variance, beta, gamma) + # to + # (data, gamma, beta, moving_mean, moving_var) + new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]] + + axis = 3 + if "data_format" in attr: + attr["data_format"] = attr["data_format"].decode("utf-8") + if attr["data_format"] == "NCHW": + axis = 1 + + return AttrCvt( + op_name="batch_norm", + transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"}, + extras={"axis": axis}, + ignores=["data_format", "exponential_avg_factor"], + disables=["momentum"], + )(new_inputs, attr) + + return _impl + + +def _relu6(): + def _impl(inputs, attr, params, mod): + return _op.clip(inputs[0], a_min=0, a_max=6) + + return _impl + + +def _shape(): + def _impl(inputs, attr, params, mod): + is_symbolic_shape = False + input_shape = _infer_shape(inputs[0], mod) + for axis in input_shape: + if not isinstance(axis, (int, tvm.tir.IntImm)): + is_symbolic_shape = True + break + + if is_symbolic_shape: + ret = _op.shape_of(inputs[0], dtype=attr["out_type"].name) + else: + ret = np.array(input_shape, dtype=attr["out_type"].name) + return ret + + return _impl + + +def _fill(): + def _impl(inputs, attr, params, mod): + try: + output_shape = _infer_value(inputs[0], params, mod).numpy().tolist() + except Exception: + output_shape = inputs[0] + + return _op.full(inputs[1], output_shape, attr["T"].name) + + return _impl + + +def _lrn(): + def _impl(inputs, attr, params, mod): + attr_new = {} + depth_radius = attr.get("depth_radius", 5) + size = (depth_radius * 2) + 1 + attr_new["axis"] = 3 # Fix axis, NHWC format + attr_new["size"] = size + attr_new["bias"] = attr.get("bias", 1) + attr_new["alpha"] = attr.get("alpha", 1) * size + attr_new["beta"] = attr.get("beta", 0.5) + return AttrCvt(op_name="lrn")(inputs, attr_new) + + return _impl + + +def _sum(): + def _impl(inputs, attr, params, mod): + axis = _get_tuple_param(params, inputs[1]) + return AttrCvt( + op_name="sum", + extras={"axis": axis}, + transforms={"keep_dims": "keepdims"}, + ignores=["name", "Tidx"], + )([inputs[0]], attr) + + return _impl + + +def _reduce(op): + def _impl(inputs, attr, params, mod): + axis = _get_list_param(params, inputs[1], mod) + axis = tuple(axis) + if not axis: + axis = None + return AttrCvt( + op_name=op, + extras={"axis": axis}, + transforms={"keep_dims": "keepdims"}, + ignores=["name", "Tidx"], + )([inputs[0]], attr) + + return _impl + + +def _euclidean_norm(): + def _impl(inputs, attr, params, mod): + axis = tuple(_get_list_param(params, inputs[1], mod)) + keep_dims = bool(attr.get("keep_dims", False)) + return _op.sqrt( + _op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32") + ) + + return _impl + + +def _square(): + def _impl(inputs, attr, params, mod): + return _op.multiply(inputs[0], inputs[0]) + + return _impl + + +def _gather(): + "GatherV2, Gather" + + def _impl(inputs, attr, params, mod): + if len(inputs) > 2: + axis = _get_num_param(params, inputs.pop(2)) + else: + axis = 0 + batch_dims = 0 + if int(attr.get("batch_dims", 0)) != 0: + batch_dims = int(attr.get("batch_dims", 0)) + new_input = inputs[0:2] + op_ = AttrCvt( + op_name="take", + extras={ + "axis": tvm.tir.const(axis, "int32"), + "batch_dims": tvm.tir.const(batch_dims, "int32"), + }, + ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"], + )(new_input, attr) + return op_ + + return _impl + + +def _gather_nd(): + """GatherNd""" + + def _impl(inputs, attr, params, mod): + indices_dims = len(_infer_shape(inputs[1], mod)) + indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) + return AttrCvt(op_name="gather_nd", ignores=["Tindices", "Tparams", "Taxis", "_class"])( + [inputs[0], indices], attr + ) + + return _impl + + +def _stridedSlice(): + def _impl(inputs, attr, params, mod): + """Strided Slice. + Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice + Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ + tensorflow/core/util/strided_slice_op.cc#L147-L368 + """ + begin = _get_list_param(params, inputs[1], mod) + end = _get_list_param(params, inputs[2], mod) + stride = _get_list_param(params, inputs[3], mod) + + begin_mask = int(attr.get("begin_mask", 0)) + end_mask = int(attr.get("end_mask", 0)) + ellipsis_mask = int(attr.get("ellipsis_mask", 0)) + new_axis_mask = int(attr.get("new_axis_mask", 0)) + shrink_axis_mask = int(attr.get("shrink_axis_mask", 0)) + in_type = _infer_type(inputs[0], mod) + data_shape = get_const_tuple(in_type.checked_type.shape) + data_dim = len(data_shape) + stride_dim = len(stride) + if data_dim == 0 and isinstance(inputs[0], _expr.Constant): + new_data = inputs[0].data.numpy().reshape(1) + return _expr.const(new_data, inputs[0].data.dtype) + + # This is a special routine to handle strided_slice after shape_of. + # We need this since in some cases we want to do strided_slice on + # a partial symbolic shape, such as (1, ?), and get a static shape + # (1,). Directly slice on shape_of will result in fully dynamic shape. + # TODO(kevinthesun): Can we generalize this process with partial eval? + if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"): + bg = begin[0] + ed = end[0] + st = stride[0] + + if ed <= 0 < st: + ed += data_shape[0] + + in_shape = _infer_shape(inputs[0].args[0], mod) + dtype = in_type.checked_type.dtype + out_data = [] + idx = bg + while idx < ed: + if isinstance(in_shape[idx], int): + out_data.append(in_shape[idx]) + else: + break + idx += st + + # Only return when in_shape is fully static in the range from begin to end. + if idx >= ed: + ret = _expr.const(out_data, dtype) + if shrink_axis_mask: + ret = _op.squeeze(ret) + + return ret + + def _transform_mask(stride_dim, ellipsis_mask): + """Handle mask inputs to create new begin, end, stride and output shape""" + m_begin = [0] * data_dim + m_end = [0] * data_dim + m_stride = [0] * data_dim + fshape_indices = [] + # Count new axis after ellipsis_mask, consider while applying ellipsis_mask. + ellipsis_seen = False + new_axes_after_ellipsis = 0 + for i in range(stride_dim): + mask = 1 << i + if ellipsis_seen and (mask & new_axis_mask) != 0: + new_axes_after_ellipsis += 1 + if (mask & ellipsis_mask) != 0: + ellipsis_seen = True + if not ellipsis_seen: + # Used later for extending the stride attributes in the below loop. + ellipsis_mask |= 1 << stride_dim + stride_dim += 1 + final_index = 0 + for index in range(stride_dim): + mask = 1 << index + if mask & ellipsis_mask: + # Identify the end index for applying ellipsis_mask + to_index = min( + ((data_dim - (stride_dim - index)) + 1 + new_axes_after_ellipsis), data_dim + ) + for i in range(final_index, to_index): + m_begin[final_index] = 0 + m_end[final_index] = data_shape[final_index] + m_stride[final_index] = 1 + fshape_indices.append(final_index) + final_index += 1 + elif mask & new_axis_mask: + fshape_indices.append(-1) + elif not mask & new_axis_mask: + if final_index == len(m_begin): + break + if mask & begin_mask: + m_begin[final_index] = -1 if stride[index] < 0 else 0 + elif begin[index]: + m_begin[final_index] = begin[index] + if mask & end_mask: + m_end[final_index] = ( + -(data_shape[final_index] + 1) + if stride[index] < 0 + else data_shape[final_index] + ) + elif end[index]: + m_end[final_index] = end[index] + m_stride[final_index] = stride[index] + if mask & shrink_axis_mask: + # Tensorflow make axis with shrink_axis_mask as dimension 1 + m_begin[final_index] = ( + data_shape[final_index] + begin[index] + if begin[index] < 0 + else begin[index] + ) + m_end[final_index] = begin[index] + 1 + m_stride[final_index] = 1 + fshape_indices.append(-2) + else: + fshape_indices.append(final_index) + + final_index += 1 + return m_begin, m_end, m_stride, fshape_indices + + fshape_indices = None + if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: + begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) + out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) + out_shape = _infer_shape(out, mod=mod) + if not fshape_indices: + fshape_indices = range(len(out_shape)) + + # Create final output shape. + final_output = [] + for gather_index in fshape_indices: + if gather_index == -1: + final_output.append(1) + elif gather_index == -2: + pass + else: + final_output.append(out_shape[gather_index]) + + if not final_output: + if not shrink_axis_mask: + ret = out + else: + final_shape = [] + for dim in out_shape: + if dim != 1: + final_shape.append(dim) + if len(final_shape) == 0: + ret = _op.squeeze(out) + else: + # We need reshape to handle dynamic shape. + ret = _op.reshape(out, newshape=tuple(final_shape)) + else: + ret = _op.reshape(out, newshape=tuple(final_output)) + return ret + + return _impl + + +def _pad(name): + def _impl(inputs, attr, params, mod): + try: + padlist = _get_param(params, inputs[1]) + except (IndexError, KeyError, AttributeError): + try: + padlist = _infer_value(inputs[1], params, mod).numpy().tolist() + except Exception: + padlist = inputs[1] + + if isinstance(padlist, _expr.Expr): + paddings = padlist + else: + paddings = tuple(tuple(l) for l in padlist) + attr["pad_width"] = paddings + attr["pad_value"] = 0 + new_inputs = [inputs[0]] + if name == "PadV2": + try: + attr["pad_value"] = _get_num_param(params, inputs[2]) + except (IndexError, KeyError, AttributeError): + attr["pad_value"] = inputs[2] + return AttrCvt( + op_name="pad", + ignores=["Tpaddings"], + )(new_inputs, attr) + + return _impl + + +def _mirror_pad(): + def _impl(inputs, attr, params, mod): + padlist = _get_param(params, inputs[1]) + paddings = tuple(tuple(l) for l in padlist) + attr["pad_width"] = paddings + mode = attr["mode"].decode("utf-8") + attr["mode"] = mode + new_inputs = [inputs[0]] + return AttrCvt( + op_name="mirror_pad", + ignores=["Tpaddings"], + )(new_inputs, attr) + + return _impl + + +def _transpose(): + def _impl(inputs, attr, params, mod): + # If perm is not specified, axes is left empty, + # otherwise its value is get from params + axes = _get_list_param(params, inputs[1], mod) + return _op.transpose(inputs[0], axes=axes) + + return _impl + + +def _where(): + def _impl(inputs, attr, params, mod): + if len(inputs) == 1: + return AttrCvt(op_name="argwhere")(inputs, attr) + return AttrCvt(op_name="where")(inputs, attr) + + return _impl + + +def _clip_by_value(): + def _impl(inputs, attr, params, mod): + a_min = _get_num_param(params, inputs[1]) + a_max = _get_num_param(params, inputs[2]) + return _op.clip(inputs[0], a_min=a_min, a_max=a_max) + + return _impl + + +def _reverse_v2(): + def _impl(inputs, attr, params, mod): + axis = _get_num_param(params, inputs[1]) + return AttrCvt(op_name="reverse", ignores=["Tidx"], extras={"axis": int(axis)})( + [inputs[0]], attr + ) + + return _impl + + +def _rank(): + def _impl(inputs, attr, params, mod): + input_shape = _infer_shape(inputs[0], mod) + + name = attr["_node_name"] + params[name] = tvm.nd.array(np.array([len(input_shape)]).astype("int32")) + return [_expr.var(name, shape=params[name].shape, dtype="int32")] + + return _impl + + +def _range(): + def _impl(inputs, attr, params, mod): + try: + start = _get_param(params, inputs[0])[0] + except (IndexError, KeyError, AttributeError): + try: + start = _infer_value(inputs[1], params, mod).numpy().tolist() + start = start if not isinstance(start, list) else start[0] + except Exception: + # Symbolic start + start = inputs[0] + + try: + limit = ( + _get_param(params, inputs[1])[0] + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) + else params.pop("Rank").numpy()[0] + ) + except (IndexError, KeyError, AttributeError): + try: + limit = _infer_value(inputs[1], params, mod).numpy().tolist() + limit = limit if not isinstance(limit, list) else limit[0] + except Exception: + limit = inputs[1] + + try: + delta = _get_param(params, inputs[2])[0] + except (IndexError, KeyError, AttributeError): + try: + delta = _infer_value(inputs[2], params, mod).numpy().tolist() + delta = delta if not isinstance(delta, list) else delta[0] + except Exception: + # Symbolic delta + delta = inputs[2] + + # if all attributes are constant, evalute the range function and return relay.const + if all( + [ + isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)), + isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)), + isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)), + ] + ): + return tvm.relay.const(list(range(int(start), int(limit), int(delta)))) + + dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) + if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): + start = _expr.const(start, dtype=dtype) + if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)): + limit = _expr.const(limit, dtype=dtype) + if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)): + delta = _expr.const(delta, dtype=dtype) + + return AttrCvt( + op_name="arange", + ignores=["Tidx", "_class"], + extras={"start": start, "stop": limit, "step": delta, "dtype": dtype}, + )([], attr) + + return _impl + + +def _elu(): + def _impl(inputs, attr, params, mod): + dtype = attr["T"].name + alpha = tvm.relay.const(-1.0, dtype) + return alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) + _op.nn.relu( + inputs[0] + ) + + return _impl + + +def _selu(): + def _impl(inputs, attr, params, mod): + dtype = attr["T"].name + alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) + gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) + return gamma * ( + alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0]) + ) + + return _impl + + +def _mean(): + def _impl(inputs, attr, params, mod): + axis = _get_tuple_param(params, inputs[1]) + return AttrCvt( + op_name="mean", + ignores=["Tdim", "Tidx"], + transforms={"keep_dims": "keepdims"}, + extras={"axis": axis}, + )([inputs[0]], attr) + + return _impl + + +def _broadcast(name): + def _impl(inputs, attr, params, mod): + return AttrCvt(op_name=name, ignores=["name", "incompatible_shape_error", "Tidx"])( + inputs, attr + ) + + return _impl + + +def _split(has_size_vector): + # TF documentation https://www.tensorflow.org/api_docs/python/tf/split + def _impl(inputs, attr, params, mod): + try: + # order and number of inputs are different: + # if has_size_vector: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v + # else: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split + + # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow, + # we can only support constants + if has_size_vector: + input_node_index = 0 + input_axis_index = 2 + size_splits = _get_param(params, inputs[1]) + section_beginnings = np.cumsum(size_splits)[:-1] + indices_or_sections = tuple(section_beginnings) + else: + input_node_index = 1 + input_axis_index = 0 + indices_or_sections = attr["num_split"] + input_node = inputs[input_node_index] + axis_input_value = _get_num_param(params, inputs[input_axis_index]) + except (IndexError, KeyError, AttributeError): + raise TypeError( + "Unsupported argument for split: `axis` and `num_or_size_splits` " + "should be constants" + ) + return _op.split( + input_node, indices_or_sections=indices_or_sections, axis=int(axis_input_value) + ) + + return _impl + + +def _unpack(): + def _impl(inputs, attr, params, mod): + input_node = inputs[0] + axis = attr["axis"] + input_shape = _infer_shape(input_node, mod) + axis_length = input_shape[axis] + if axis_length < 0: + raise TypeError("Unstack with unknown axis length") + splitted = _op.split(input_node, indices_or_sections=axis_length, axis=axis) + axis = [axis] + return _expr.TupleWrapper( + _expr.Tuple([_op.squeeze(split_item, axis=axis) for split_item in splitted]), + len(splitted), + ) + + return _impl + + +def _softmax(): + def _impl(inputs, attr, params, mod): + return AttrCvt(op_name="softmax", transforms={"axis": ("axis", 1)})([inputs[0]], attr) + + return _impl + + +def _softsign(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/softsign + def _impl(inputs, attr, params, mod): + abs_out = get_relay_op("abs")(inputs[0]) + add_out = abs_out + tvm.relay.const(1, attr["T"].name) + return inputs[0] / add_out + + return _impl + + +def _softplus(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus + def _impl(inputs, attr, params, mod): + exp_out = AttrCvt("exp")(inputs, attr) + inputs.append(tvm.relay.const(1, attr["T"].name)) + rh = tvm.relay.const(1, attr["T"].name) + add_out = get_relay_op("add")(exp_out, rh) + return get_relay_op("log")(add_out) + + return _impl + + +def _topk(): + def _impl(inputs, attr, params, mod): + k_input = inputs.pop(1) + try: + k = int(_get_num_param(params, k_input)) + except (IndexError, KeyError, AttributeError): + try: + k = int(_infer_value(k_input, params, mod).numpy().tolist()) + except Exception: + k = k_input + if isinstance(k, int): + if k < 1: + raise tvm.error.OpAttributeInvalid( + "Attribute k must be positive in operator TopKV2" + ) + k = _expr.const(k) + if attr["sorted"] is False: + raise tvm.error.OpAttributeUnImplemented( + "Attribute sorted=False is not supported in operator TopKV2" + ) + return AttrCvt( + op_name="topk", + ignores=["sorted"], + extras={"k": k, "is_ascend": False, "dtype": "int32"}, + )([inputs[0]], attr) + + return _impl + + +def _floordiv(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2 + return AttrCvt("floor_divide")(inputs, attr) + + return _impl + + +def _floormod(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2 + return AttrCvt("floor_mod")(inputs, attr) + + return _impl + + +def _logical(name): + def _impl(inputs, attr, params, mod): + return AttrCvt(op_name=name)(inputs, attr) + + return _impl + + +def _space_to_batch_nd(): + def _impl(inputs, attr, params, mod): + block_shape = _get_list_param(params, inputs[1], mod) + + paddings = _get_list_param(params, inputs[2], mod) + paddings = np.squeeze(paddings) + if len(paddings.shape) == 1: + paddings = np.expand_dims(paddings, axis=0) + paddings = paddings.tolist() + + attr["block_shape"] = block_shape + attr["paddings"] = paddings + out = AttrCvt("space_to_batch_nd", ignores=["Tblock_shape", "Tpaddings"])([inputs[0]], attr) + + return out + + return _impl + + +def _batch_to_space_nd(): + def _impl(inputs, attr, params, mod): + block_shape = _get_list_param(params, inputs[1], mod) + + crops = _get_list_param(params, inputs[2], mod) + crops = np.squeeze(crops) + if len(crops.shape) == 1: + crops = np.expand_dims(crops, axis=0) + crops = crops.tolist() + + attr["block_shape"] = block_shape + attr["crops"] = crops + out = AttrCvt("batch_to_space_nd", ignores=["Tblock_shape", "Tcrops"])([inputs[0]], attr) + + return out + + return _impl + + +def _atan2(): + def _impl(inputs, attr, params, mod): + divide = _elemwise("divide")(inputs, attr, params, mod) + return get_relay_op("atan")(divide) + + return _impl + + +def _prod(): + def _impl(inputs, attr, params, mod): + axis = _get_num_param(params, inputs[1]) + keepdims = attr["keep_dims"] + return _op.prod(inputs[0], int(axis), keepdims=keepdims) + + return _impl + + +def _log1p(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p + def _impl(inputs, attr, params, mod): + one = tvm.relay.const(1, attr["T"].name) + add_out = get_relay_op("add")(inputs[0], one) + return get_relay_op("log")(add_out) + + return _impl + + +def _one_hot(): + def _impl(inputs, attr, params, mod): + depth = int(_get_num_param(params, inputs[1])) + dtype = attr["T"].name + + on_value = _get_num_param(params, inputs[2]) + off_value = _get_num_param(params, inputs[3]) + new_inputs = [ + inputs[0], + tvm.relay.const(on_value, dtype), + tvm.relay.const(off_value, dtype), + ] + return AttrCvt("one_hot", ignores=["TI"], extras={"depth": depth, "dtype": dtype})( + new_inputs, attr + ) + + return _impl + + +def _squared_difference(): + def _impl(inputs, attr, params, mod): + difference = _op.subtract(inputs[0], inputs[1]) + return _op.multiply(difference, difference) + + return _impl + + +def _size(): + def _impl(inputs, attr, params, mod): + new_attr = attr + new_attr["out_type"] = attr["out_type"].name + return AttrCvt("ndarray_size", transforms={"out_type": "dtype"})(inputs, new_attr) + + return _impl + + +def _add_n(): + def _impl(inputs, attr, params, mod): + if not isinstance(inputs, tuple): + inputs = list(inputs) + assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." + _res = inputs[0] + for each in inputs[1:]: + _res = _op.add(_res, each) + return _res + + return _impl + + +def _LSTMBlockCell(): + def _impl(inputs, attr, params, mod): + """LSTM Block cell. + Calculations and return values are described in: + https://github.com/tensorflow/tensorflow/blob/ + r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 + + Parameters + ---------- + inputs : relay.Expr + Input data + in_state_c: list of relay.Expr + Cell state input values for all the layers + in_state_h: list of relay.Expr + Hidden state input values for all the layers + attrs : dict + Dict of operator attributes + params : dict + List of pretrained weights and bias + + Returns + ------- + relay.Expr.TupleWapper + [i, cs, f, o, ci, co, h] + """ + in_data = inputs[0] + in_state_c = inputs[1] + in_state_h = inputs[2] + in_weight = inputs[3] + in_bias = inputs[7] + forget_bias = attr.pop("forget_bias") + input_shape = _infer_shape(inputs[0], mod) + weight_shape = _infer_shape(inputs[3], mod) + batch_size, input_size = input_shape[0], input_shape[1] + num_hidden_layers = weight_shape[1] + + in_data = _op.reshape(in_data, newshape=(batch_size, input_size)) + ixh = _op.concatenate([in_data, in_state_h], axis=1) + in_weight = _op.transpose(in_weight, axes=None) + gates = _op.nn.dense(ixh, in_weight, units=num_hidden_layers) + gates_bias = _op.add(gates, in_bias) + gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) + in_gate = _op.sigmoid(gate_list[0]) + in_transform = _op.tanh(gate_list[1]) + forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr["T"].name)) + forget_gate = _op.sigmoid(forget_gate) + out_gate = _op.sigmoid(gate_list[3]) + next_c = _op.add(_op.multiply(forget_gate, in_state_c), _op.multiply(in_gate, in_transform)) + co = _op.tanh(next_c) + next_h = out_gate * co + + return tvm.relay.TupleWrapper( + tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7 + ) + + return _impl + + +def _unique(return_counts=True): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 1 + data = inputs[0] + if return_counts: + [unique, _, inverse_indices, num_uniq, counts] = _op.unique( + data, is_sorted=False, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, inverse_indices, counts_sliced]), + 3, + ) + [unique, _, inverse_indices, num_uniq] = _op.unique( + data, is_sorted=False, return_counts=False + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, inverse_indices]), + 2, + ) + + return _impl + + +# _convert_map defines maps of name to converter functor(callable) +# for 1 to 1 mapping, use Renamer if nothing but name is different +# use AttrCvt if attributes need to be converted +# for 1 to N mapping(composed), use custom callable functions +# for N to 1 mapping, currently not supported(?) +_convert_map = { + "Abs": AttrCvt("abs"), + "Acos": AttrCvt("acos"), + "Acosh": AttrCvt("acosh"), + "Add": _elemwise("add"), + "AddN": _add_n(), + "AddV2": _elemwise("add"), + "All": _reduce("all"), + "Any": _reduce("any"), + "ArgMax": _argx(_op.argmax, "argmax"), + "ArgMin": _argx(_op.argmin, "argmin"), + "Asin": AttrCvt("asin"), + "Asinh": AttrCvt("asinh"), + "Assert": _assert(), + "Atan": AttrCvt("atan"), + "Atanh": AttrCvt("atanh"), + "Atan2": _atan2(), + "AvgPool": _pooling("avg_pool"), + "AvgPool3D": _pool3d("avg_pool3d"), + "BatchMatMul": _batch_matmul(), + "BatchMatMulV2": _batch_matmul(), + "BatchNormWithGlobalNormalization": _batch_norm(), + "BatchToSpaceND": _batch_to_space_nd(), + "BiasAdd": _bias_add(), + "BroadcastTo": _broadcast_to(), + "BroadcastArgs": _broadcast_args(), + "Cast": _cast(), + "Ceil": AttrCvt("ceil"), + "CheckNumerics": _check_numerics(), + "ClipByValue": _clip_by_value(), + "Concat": _concat(), + "ConcatV2": _concatV2(), + "Conv2D": _conv("conv"), + "Conv2DBackpropInput": _conv("conv_transpose"), + "Conv3D": _conv3d("conv"), + "Conv3DBackpropInputV2": _conv3d("conv_transpose"), + "Cos": AttrCvt("cos"), + "Cosh": AttrCvt("cosh"), + "CropAndResize": _crop_and_resize(), + "DecodeJpeg": _decode_image(), + "DepthToSpace": _depth_to_space(), + "DepthwiseConv2dNative": _conv("depthwise"), + "Dilation2D": _dilation2d(), + "Elu": _elu(), + "Equal": _broadcast("equal"), + "Erf": AttrCvt("erf"), + "EuclideanNorm": _euclidean_norm(), + "Exp": AttrCvt("exp"), + "ExpandDims": _expand_dims(), + "Expm1": _expm1(), + "Fill": _fill(), + "Floor": AttrCvt("floor"), + "FloorDiv": _floordiv(), + "FloorMod": _floormod(), + "FusedBatchNorm": _fused_batch_norm(), + "FusedBatchNormV2": _fused_batch_norm(), + "FusedBatchNormV3": _fused_batch_norm(), + "Gather": _gather(), + "GatherNd": _gather_nd(), + "GatherV2": _gather(), + "Greater": _broadcast("greater"), + "GreaterEqual": _broadcast("greater_equal"), + "Identity": _identity(), + "IdentityN": _identityn(), + "IsFinite": AttrCvt("isfinite"), + "IsInf": AttrCvt("isinf"), + "IsNan": AttrCvt("isnan"), + "LeakyRelu": AttrCvt("leaky_relu"), + "LeftShift": AttrCvt("left_shift"), + "Less": _broadcast("less"), + "LessEqual": _broadcast("less_equal"), + "Log": AttrCvt("log"), + "Log1p": _log1p(), + "LogicalAnd": _logical("logical_and"), + "LogicalNot": _logical("logical_not"), + "LogicalOr": _logical("logical_or"), + "LogSoftmax": AttrCvt("log_softmax"), + "LRN": _lrn(), + "LSTMBlockCell": _LSTMBlockCell(), + "MatMul": _matmul(), + "Max": _reduce("max"), + "Maximum": _elemwise("maximum"), + "MaxPool": _pooling("max_pool"), + "MaxPool3D": _pool3d("max_pool3d"), + "Mean": _mean(), + "Min": _reduce("min"), + "Minimum": _elemwise("minimum"), + "MirrorPad": _mirror_pad(), + "Mod": _elemwise("mod"), + "Mul": _elemwise("multiply"), + "Neg": AttrCvt("negative"), + "NonMaxSuppressionV2": _nms(), + "NonMaxSuppressionV3": _nms(), + "NonMaxSuppressionV4": _nms(), + "NonMaxSuppressionV5": _nms(True), + "CombinedNonMaxSuppression": _combined_nms(), + "NoOp": _no_op(), + "NotEqual": _broadcast("not_equal"), + "OneHot": _one_hot(), + "Pack": _pack(), + "Pad": _pad("Pad"), + "PadV2": _pad("PadV2"), + "Pow": _elemwise("power"), + "Prod": _prod(), + "Range": _range(), + "Rank": _rank(), + "RealDiv": _elemwise("divide"), + "Relu": AttrCvt("relu"), + "Relu6": _relu6(), + "Reshape": _reshape(), + "ResizeBicubic": _resize("bilinear"), + "ResizeBilinear": _resize("bilinear"), + "ResizeNearestNeighbor": _resize("nearest_neighbor"), + "ReverseV2": _reverse_v2(), + "RightShift": AttrCvt("right_shift"), + "Rint": AttrCvt("round"), + "Round": AttrCvt("round"), + "Rsqrt": _rsqrt(), + "Select": _where(), + "SelectV2": _where(), + "Selu": _selu(), + "Shape": _shape(), + "Sigmoid": AttrCvt("sigmoid"), + "Sign": AttrCvt("sign"), + "Sin": AttrCvt("sin"), + "Sinh": AttrCvt("sinh"), + "Size": _size(), + "Slice": _slice(), + "Softmax": _softmax(), + "Softplus": _softplus(), + "Softsign": _softsign(), + "SpaceToBatchND": _space_to_batch_nd(), + "SpaceToDepth": _space_to_depth(), + "SparseToDense": _sparse_to_dense(), + "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), + "SparseFillEmptyRows": _sparse_fill_empty_rows(), + "SparseReshape": _sparse_reshape(), + "SegmentSum": _math_segment_sum(), + "SparseSegmentSum": _sparse_segment_sum(), + "SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(), + "SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(), + "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(), + "SparseSegmentMean": _sparse_segment_mean(), + "SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(), + "SparseTensorDenseAdd": _sparse_tensor_dense_add(), + "Split": _split(False), + "SplitV": _split(True), + "Sqrt": AttrCvt("sqrt"), + "Square": _square(), + "SquaredDifference": _squared_difference(), + "Squeeze": _squeeze(), + "StopGradient": _identity(), + "StridedSlice": _stridedSlice(), + "Sub": _elemwise("subtract"), + "Sum": _sum(), + "Tan": AttrCvt("tan"), + "Tanh": AttrCvt("tanh"), + "TensorArrayConcatV3": _tensor_array_concat(), + "TensorArrayGatherV3": _tensor_array_gather(), + "TensorArrayReadV3": _tensor_array_read(), + "TensorArrayScatterV3": _tensor_array_scatter(), + "TensorArraySizeV3": _tensor_array_size(), + "TensorArraySplitV3": _tensor_array_split(), + "TensorArrayV3": _tensor_array(), + "TensorArrayWriteV3": _tensor_array_write(), + "Tile": _tile(), + "TopKV2": _topk(), + "Transpose": _transpose(), + "TruncateMod": _elemwise("mod"), + "Unique": _unique(False), + "UniqueWithCounts": _unique(True), + "Unpack": _unpack(), + "UnravelIndex": _unravel_index(), + "Where": _where(), + "ZerosLike": AttrCvt("zeros_like"), +} diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index 734908214dce..4b5d2b9c605c 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -76,7 +76,7 @@ def __init__(self, buffer): self.buffer = buffer def indirect_jump(self, offset, byte_width): - """ Helper function to read the offset value and jump """ + """Helper function to read the offset value and jump""" unpack_str = "" if byte_width == 1: unpack_str = " i: + cstride = int64(strides[i]) + if len(begin) > i: + cbegin = int64(begin[i]) + if cbegin < 0: + cbegin += int64(data_shape[axes[i]]) + if len(end) <= i: + cend = int64(data_shape[axes[i]]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data_shape[axes[i]]) + else: + cend = cbegin + int64(end[i]) + else: + if end[i] > data_shape[i]: + cend = int64(data_shape[axes[i]]) + elif end[i] < -data_shape[i]: + cend = int64(-1) + else: + cend = int64(end[i]) + if cend < 0: + cend += int64(data_shape[axes[i]]) + assert cstride != 0, "Strides can't be zero." + if cstride < 0: + slice_range = cbegin - cend + step = -cstride + else: + slice_range = cend - cbegin + step = cstride + + out[axes[i]] = int64(ceil_div(slice_range, step)) + return out + + @_reg.register_shape_func("strided_slice", False) def strided_slice_shape_func(attrs, inputs, _): """ Shape func for strided_slice """ slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + if attrs.axes is None: + return [ + _strided_slice_shape_func_input_shape( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode + ) + ] return [ - _strided_slice_shape_func_input_shape( - inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode + _strided_slice_shape_func_with_axes( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes ) ] @@ -1045,24 +1097,28 @@ def ensure_tensor(tensor): def _unique_shape(data_shape): unique_shape = output_tensor((1,), "int64") indices_shape = output_tensor((1,), "int64") + inverse_indices_shape = output_tensor((1,), "int64") num_unique_shape = output_tensor((1,), "int64") unique_shape[0] = data_shape[0] indices_shape[0] = data_shape[0] + inverse_indices_shape[0] = data_shape[0] num_unique_shape[0] = int64(1) - return (unique_shape, indices_shape, num_unique_shape) + return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape) @script def _unique_with_counts_shape(data_shape): unique_shape = output_tensor((1,), "int64") indices_shape = output_tensor((1,), "int64") + inverse_indices_shape = output_tensor((1,), "int64") num_unique_shape = output_tensor((1,), "int64") counts_shape = output_tensor((1,), "int64") unique_shape[0] = data_shape[0] indices_shape[0] = data_shape[0] + inverse_indices_shape[0] = data_shape[0] num_unique_shape[0] = int64(1) counts_shape[0] = data_shape[0] - return (unique_shape, indices_shape, num_unique_shape, counts_shape) + return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape, counts_shape) @_reg.register_shape_func("unique", False) @@ -1074,3 +1130,34 @@ def unique_shape_func(attrs, inputs, _): return _unique_with_counts_shape(inputs[0]) else: return _unique_shape(inputs[0]) + + +@script +def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank): + ndim = data_shape.shape[0] + # using mdim = indices_shape[0] wouldn't work because a rank cannot + # depend on a runtime shape dimension of indices tensor, even if the + # dimension is always a known, fixed value. As a workaround, we assume that + # the fixed gather dimension (the size of an indexing tuple) is recorded + # in gather_nd op attributes. + mdim = index_rank + kdim = indices_shape.shape[0] - 1 + out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64") + for i in range(1, kdim + 1): + out_shape[i - 1] = indices_shape[i] + for i in range(mdim + batch_dims, ndim): + out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i] + return out_shape + + +@_reg.register_shape_func("gather_nd", False) +def gather_nd_shape_func(attrs, inputs, _): + """ + Shape func for gather_nd operator. + """ + batch_dims = get_const_int(attrs.batch_dims) + index_rank = get_const_int(attrs.index_rank) + + assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd" + + return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 9152b50e7686..9f3c1cdec0f7 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -397,6 +397,14 @@ def qnn_dense(expr): return True +def check_dilation(attrs): + """Prevents offloading if dilation other than (1, 1)""" + if not isinstance(attrs, relay.op.op_attrs.GlobalPool2DAttrs): + if not (len(attrs.dilation) == 2 and attrs.dilation[0] == 1 and attrs.dilation[1] == 1): + return False + return True + + @tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib") def max_pool2d(expr): """Check if the external ACL codegen for maxpool2d should be used.""" @@ -406,7 +414,7 @@ def max_pool2d(expr): typ = args[0].checked_type if typ.dtype not in ["float32", "uint8"]: return False - return True + return check_dilation(attrs) @tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib") @@ -424,7 +432,7 @@ def avg_pool2d(expr, from_quantized_composite=False): if attrs.layout != "NHWC": return False - return True + return check_dilation(attrs) @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib") @@ -483,7 +491,7 @@ def qnn_add(expr): class OpAttrContext(object): - """ Temporarily changes the attr of an op. """ + """Temporarily changes the attr of an op.""" def __init__(self, op_name, attr_key, attr_value): """Saves the required info for RAII pattern usage. diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index e23d2026661d..5b7fd32add4c 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -31,7 +31,7 @@ # resize @reg.register_compute("image.resize") def compute_resize(attrs, inputs, out_type): - """ compute definition for resize op """ + """compute definition for resize op""" size = attrs.size layout = attrs.layout method = attrs.method diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 91c148b5df2e..caf1f187fad3 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2236,6 +2236,7 @@ def sparse_add(dense_mat, sparse_mat): Examples ------- .. code-block:: python + dense_data = [[ 3., 4., 4. ] [ 4., 2., 5. ]] sparse_data = [4., 8.] diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 5882027fb1d8..ccf011819a97 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -21,6 +21,7 @@ from tvm.driver import lower, build from tvm.target import get_native_generic_func, GenericFunc from tvm.runtime import Object +import tvm.ir._ffi_api from . import _make @@ -40,6 +41,40 @@ def get(op_name): return tvm.ir.Op.get(op_name) +def register(op_name, describe=""): + """Get the Op for a given name. + when the op_name is not registered, create a new empty op with the given name. + when the op_name has been registered, abort with an error message. + + Parameters + ---------- + op_name : str + The operator name + + describe : Optional[str] + The operator description + """ + + tvm.ir._ffi_api.RegisterOp(op_name, describe) + + +def register_stateful(op_name, stateful, level=10): + """Register operator pattern for an op. + + Parameters + ---------- + op_name : str + The name of the op. + + stateful : bool + The stateful flag. + + level : int + The priority level + """ + tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level) + + class OpPattern(object): """Operator generic patterns @@ -401,6 +436,27 @@ def register_external_compiler(op_name, fexternal=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level) +def register_fake_quantization_to_integer(op_name, func=None, level=10): + """Register quantize function for an op + + Given an op and Affine Types on it's inputs, this function should return the op + in affine space/integer operators and the new type of the output, where affine + denotes the transformation x_real = (x_affine - zero_point) * scale + + Parameters + ---------- + op_name : str + The name of the operator + + func: function (expr: Expr, map: Map) -> new_expr: Expr + The function for translating the op into affine space and integer operators + + level : int + The priority level + """ + return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level) + + @tvm._ffi.register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) diff --git a/python/tvm/relay/op/random/kernel.py b/python/tvm/relay/op/random/kernel.py index fc1248e85678..6c82cc154eb6 100644 --- a/python/tvm/relay/op/random/kernel.py +++ b/python/tvm/relay/op/random/kernel.py @@ -77,8 +77,7 @@ def threefry_generate(key, shape): this function.** shape : Sequence[int] - Desired outputs shape of random numbers. **Currently the total - number of elements must be a multiple of 4.** + Desired outputs shape of random numbers. Returns ------- diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 6c5b1e0cdead..b4db412700a7 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -89,6 +89,18 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target): return strategy +@fast_softmax_strategy.register(["cuda", "gpu"]) +def fast_softmax_strategy_cuda(attrs, inputs, out_type, target): + """fast_softmax cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.fast_softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="fast_softmax.cuda", + ) + return strategy + + @schedule_log_softmax.register(["cuda", "gpu"]) def schedule_log_softmax_cuda(attrs, outs, target): """scheudle log_softmax for cuda""" @@ -240,10 +252,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): tensorcore_dtypes = ["int4", "uint4", "int8", "uint8"] if ( - (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0) - or (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0) - or (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0) - and (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes) + target.kind.name == "cuda" + and nvcc.have_tensorcore(target=target) + and kernel.dtype in tensorcore_dtypes + and ( + ( + data.dtype in ["int4", "uint4"] + and N % 8 == 0 + and in_channels % 32 == 0 + and out_channels % 8 == 0 + ) + or ( + data.dtype in ["int8", "uint8"] + and N % 8 == 0 + and in_channels % 16 == 0 + and out_channels % 32 == 0 + ) + ) ): strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), @@ -309,17 +334,37 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): cudnn_impl = True if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.cuda", - ) + _, channels, _, _ = get_const_tuple(data.shape) + out_channels, in_channels, _, _ = get_const_tuple(kernel.shape) + oc_chunk = out_channels // 4 + ic_chunk = in_channels // 4 + + if ( + data.dtype in ["int8", "uint8"] + and kernel.dtype in ["int8", "uint8"] + and channels % groups == 0 + and out_channels % groups == 0 + and channels % 4 == 0 + and out_channels % 4 == 0 + and groups <= oc_chunk + and groups <= ic_chunk + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw_int8, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw_int8), + name="group_conv2d_nchw_int8.cuda", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), + name="group_conv2d_nchw.cuda", + ) elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), + wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, has_groups=True), wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), name="group_conv2d_NCHWc_int8.cuda", ) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index a6ad06e544a6..d56820e409aa 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -175,7 +175,7 @@ def fast_softmax_strategy(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.fast_softmax), - naive_schedule, + wrap_topi_schedule(topi.generic.schedule_fast_softmax), name="fast_softmax.generic", ) return strategy @@ -1095,7 +1095,15 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + output_format = attrs.output_format + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold, + output_format, + ) return _compute_nms diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 60bd92ef63d1..c21ec4d13906 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -79,6 +79,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): return strategy +@fast_softmax_strategy.register("cpu") +def fast_softmax_strategy_cpu(attrs, inputs, out_type, target): + """fast_softmax x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.fast_softmax), + wrap_topi_schedule(topi.x86.schedule_softmax), + name="fast_softmax.x86", + ) + return strategy + + @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c87f545c138a..049ddc9622ba 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, strides=None, slice_mode="end"): +def strided_slice(data, begin, end, strides=None, axes=None, slice_mode="end"): """Strided slice of an array. Parameters @@ -885,6 +885,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. + axes : Tuple[int] or List[int], optional + Axes along which slicing is applied. When it is specified, the length of begin, end, + strides, and axes must be equal. Moreover, begin, end, strides, and axes must be + static (cannot be relay.Expr). Axes argument for dynamic parameter slicing is + not supported yet. + slice_mode : str, optional The slice mode [end, size]. end: The ending indices for the slice [default]. @@ -916,8 +922,10 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): ishape_slice = slice_like(ishape, begin) begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin) begin = _make.where(begin >= ishape_slice, ishape_slice, begin) + # TODO(masahi): Support axes argument in dynamic strided slice + assert axes is None, "Axes argument for dynamic parameter slicing is not supported yet." return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) - return _make.strided_slice(data, begin, end, strides, slice_mode) + return _make.strided_slice(data, begin, end, strides, slice_mode, axes) def strided_set(data, v, begin, end, strides=None): @@ -1072,7 +1080,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0): +def gather_nd(data, indices, batch_dims=0, index_rank=None): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1087,6 +1095,10 @@ def gather_nd(data, indices, batch_dims=0): batch_dims : int The number of batch dimensions. + index_rank : int, optional + The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] + Only needed when other dimensions of indices are dynamic. + Returns ------- ret : relay.Expr @@ -1108,7 +1120,7 @@ def gather_nd(data, indices, batch_dims=0): indices = [[1, 0]] relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ - return _make.gather_nd(data, indices, batch_dims) + return _make.gather_nd(data, indices, batch_dims, index_rank) def sequence_mask(data, valid_length, mask_value=0, axis=0): @@ -1392,6 +1404,7 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v Examples ------- .. code-block:: python + sparse_indices = [[0, 1], [0, 3], [2, 0], @@ -1413,7 +1426,6 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v [4, 0]] empty_row_indicator = [False, True, False, False, True] new_sparse_values = [1, 2, 10, 3, 4, 10] - """ new_sparse_indices, new_sparse_values, empty_row_indicator = TupleWrapper( _make.sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value), 3 @@ -1445,6 +1457,7 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape): Examples -------- .. code-block:: python + sparse_indices = [[0, 0, 0], [0, 0, 1], [0, 1, 0], @@ -1496,6 +1509,7 @@ def segment_sum(data, segment_ids, num_segments=None): Examples -------- .. code-block:: python + data = [[1, 2, 3, 4], [4, -3, 2, -1], [5, 6, 7, 8]] @@ -1566,6 +1580,7 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): Examples -------- .. code-block:: python + a = [[1,2,3], [4,5,6]] cumsum(a) # if axis is not provided, cumsum is done over the flattened input. @@ -1621,6 +1636,7 @@ def cumprod(data, axis=None, dtype=None, exclusive=None): Examples -------- .. code-block:: python + a = [[1,2,3], [4,5,6]] cumprod(a) # if axis is not provided, cumprod is done over the flattened input. @@ -1654,7 +1670,7 @@ def unique(data, is_sorted=True, return_counts=False): data : relay.Expr A 1-D tensor of integers. - sorted : bool + is_sorted : bool Whether to sort the unique elements in ascending order before returning as output. return_counts : bool @@ -1662,12 +1678,16 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : relay.Expr + unique : relay.Expr A 1-D tensor containing the unique elements of the input data tensor. indices : relay.Expr A 1-D tensor containing the index of each data element in the output tensor. + inverse_indices : relay.Expr + A 1-D tensor. For each entry in data, it contains the index of that data element in the + unique array. + num_unique : relay.Expr A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -1677,6 +1697,7 @@ def unique(data, is_sorted=True, return_counts=False): Examples -------- .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] @@ -1694,5 +1715,5 @@ def unique(data, is_sorted=True, return_counts=False): num_unique = [5] """ if return_counts: - return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) - return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8d6abf1a8c20..cab9f703e88a 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -89,7 +89,7 @@ def nms_shape_func(attrs, inputs, _): @script -def _all_class_nms_shape_func(boxes_shape, scores_shape): +def _all_class_nms_shape_func_onnx(boxes_shape, scores_shape): out_shape = output_tensor((2,), "int64") count_shape = output_tensor((1,), "int64") @@ -99,9 +99,27 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape): return out_shape, count_shape +@script +def _all_class_nms_shape_func_tf(boxes_shape, scores_shape): + out_indices_shape = output_tensor((3,), "int64") + out_scores_shape = output_tensor((2,), "int64") + count_shape = output_tensor((1,), "int64") + + out_indices_shape[0] = boxes_shape[0] + out_indices_shape[1] = scores_shape[1] * boxes_shape[1] + out_indices_shape[2] = int64(2) + out_scores_shape[0] = boxes_shape[0] + out_scores_shape[1] = scores_shape[1] * boxes_shape[1] + count_shape[0] = boxes_shape[0] + + return out_indices_shape, out_scores_shape, count_shape + + @reg.register_shape_func("vision.all_class_non_max_suppression", False) def all_class_nms_shape_func(attrs, inputs, _): - return _all_class_nms_shape_func(inputs[0], inputs[1]) + if attrs.output_format == "onnx": + return _all_class_nms_shape_func_onnx(inputs[0], inputs[1]) + return _all_class_nms_shape_func_tf(inputs[0], inputs[1]) @script diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 3f829e0b1cc7..8c54075d952c 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -152,7 +152,12 @@ def non_max_suppression( def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 + boxes, + scores, + max_output_boxes_per_class=-1, + iou_threshold=-1.0, + score_threshold=-1.0, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -175,16 +180,31 @@ def all_class_non_max_suppression( score_threshold : float or relay.Expr, optional Score threshold to filter out low score boxes early + output_format : string, optional + "onnx" or "tensorflow". Specify by which frontends the outputs are + intented to be consumed. + Returns ------- out : relay.Tuple - The output is a relay.Tuple of two tensors, the first is `indices` of size - `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor - `num_total_detection` of shape `(1,)` representing the total number of selected boxes. + If `output_format` is "onnx", the output is a relay.Tuple of two tensors, the first is + `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar + tensor `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + + If `output_format` is "tensorflow", the output is a relay.Tuple of three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") @@ -194,6 +214,15 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format, ) - return expr.TupleWrapper(out, 2) + + if output_format == "onnx": + return expr.TupleWrapper(out, 2) + + return expr.TupleWrapper(out, 3) diff --git a/python/tvm/relay/testing/temp_op_attr.py b/python/tvm/relay/testing/temp_op_attr.py index 12e3652de12f..e2d2e6bbcd42 100644 --- a/python/tvm/relay/testing/temp_op_attr.py +++ b/python/tvm/relay/testing/temp_op_attr.py @@ -22,7 +22,7 @@ class TempOpAttr(object): - """ Temporarily changes the attr of an op. """ + """Temporarily changes the attr of an op.""" def __init__(self, op_name, attr_key, attr_value): """Saves the required info for RAII pattern usage. diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index ca9996aeaaae..9ed40f85c3bc 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,3 +19,4 @@ # transformation passes from .transform import * from .recast import recast +from . import fake_quantization_to_integer diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py new file mode 100644 index 000000000000..5f4c53772eec --- /dev/null +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay functions for rewriting fake quantized ops.""" +import tvm +from tvm import relay +from ..op import register_fake_quantization_to_integer + + +def fold_constant(expr): + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.FoldConstant()(mod) + return mod["main"].body + + +@register_fake_quantization_to_integer("qnn.dequantize") +def dequantize(expr, type_map): + """Remove dequantize op""" + out = expr.args[0] + t = type_map[expr] + return [out, t.scale, t.zero_point, t.dtype] + + +@register_fake_quantization_to_integer("qnn.quantize") +def quantize(expr, type_map): + """Turn a quantize op into requantize or remove it""" + out = expr.args[0] + t = type_map[out] + in_scale = fold_constant(t.scale) + in_zero_point = fold_constant(t.zero_point) + if not ( + tvm.ir.structural_equal(in_scale, expr.args[1]) + and tvm.ir.structural_equal(in_zero_point, expr.args[2]) + and tvm.ir.structural_equal(t.dtype, expr.attrs.out_dtype) + ): + out = relay.qnn.op.requantize( + out, + in_scale, + in_zero_point, + expr.args[1], + expr.args[2], + out_dtype=expr.attrs.out_dtype, + ) + return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] + + +def register_unary_identity(op_name, op): + def identity(expr, type_map): + assert len(expr.args) == 1 + arg = expr.args[0] + t = type_map[arg] + out = op(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] + + return register_fake_quantization_to_integer(op_name, identity) + + +register_unary_identity("reshape", relay.op.reshape) +register_unary_identity("transpose", relay.op.transpose) +register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d) + + +@register_fake_quantization_to_integer("nn.avg_pool2d") +def avgpool2d(expr, type_map): + """Rewrite a avgpool op""" + arg = expr.args[0] + t = type_map[arg] + arg = relay.op.cast(arg, "int32") + out = relay.op.nn.avg_pool2d(arg, **expr.attrs) + out = relay.op.cast(out, t.dtype) + return [out, t.scale, t.zero_point, t.dtype] + + +@register_fake_quantization_to_integer("nn.bias_add") +def bias_add(expr, type_map): + """Rewrite a bias_add op""" + x, b = expr.args + x_t = type_map[x] + b_t = type_map[b] + in_scale = fold_constant(x_t.scale) + in_zero_point = fold_constant(x_t.zero_point) + if not tvm.ir.structural_equal(x_t, b_t): + b = relay.qnn.op.requantize( + b, + b_t.scale, + b_t.zero_point, + in_scale, + in_zero_point, + out_dtype=xt.dtype, + ) + out = relay.op.nn.bias_add(x, b, **expr.attrs) + return [out, x_t.scale, x_t.zero_point, x_t.dtype] + + +@register_fake_quantization_to_integer("nn.conv2d") +def conv2d(expr, type_map): + """Rewrite a conv2d op""" + attrs = {**expr.attrs} + attrs.pop("out_dtype") + x, weight = expr.args + x_t = type_map[x] + w_t = type_map[weight] + conv_scale = fold_constant(x_t.scale * w_t.scale) + conv_zp = relay.const(0) + out = relay.qnn.op.conv2d( + x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs + ) + return [out, conv_scale, conv_zp, out.attrs.out_dtype] + + +@register_fake_quantization_to_integer("concatenate") +def concat(expr, type_map): + """Rewrite a concat op""" + scales = [] + zps = [] + for arg in expr.args[0].fields: + t = type_map[arg] + scales.append(t.scale) + zps.append(t.zero_point) + + out_type = type_map[expr] + + out = relay.qnn.op.concatenate( + expr.args[0], + relay.Tuple(scales), + relay.Tuple(zps), + out_type.scale, + out_type.zero_point, + **expr.attrs, + ) + return [out, out_type.scale, out_type.zero_point, out_type.dtype] + + +@register_fake_quantization_to_integer("clip") +def clip(expr, type_map): + """Rewrite a clip op""" + arg = expr.args[0] + t = type_map[arg] + amin = expr.attrs.a_min + amax = expr.attrs.a_max + scale = fold_constant(t.scale) + z_p = fold_constant(t.zero_point) + if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant): + scale = scale.data.numpy().item() + z_p = z_p.data.numpy().item() + new_min = int(amin / scale + z_p) + new_max = int(amax / scale + z_p) + out = relay.op.clip(arg, new_min, new_max) + else: + amin = relay.op.round(relay.op.const(amin) / scale + z_p) + amax = relay.op.round(relay.op.const(amax) / scale + z_p) + out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + return [out, t.scale, t.zero_point, t.dtype] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e8bb94c501..20e045abab6c 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1171,3 +1171,31 @@ def AnnotateSpans(): The regsistered AnnotateSpans pass. """ return _ffi_api.AnnotateSpans() + + +def FakeQuantizationToInteger(): + # pylint: disable=anomalous-backslash-in-string + """ + Find regions of the graph of the form + + x w + | | + dq dq + \ / + op1 + | + op2 + | + q + + where q == qnn.quantize and dq = qnn.dequantize + and rewrite them into integer versions of op1 and op2 + + Rules for rewriting indivdual ops are in fake_quantization_to_integer.py + + Returns + ------- + ret : tvm.transform.Pass + The registered SimplifyExpr pass. + """ + return _ffi_api.FakeQuantizationToInteger() diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 265dedb63b57..71563b508290 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,7 +27,7 @@ # function exposures from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl -from .ndarray import vpi, rocm, ext_dev, micro_dev +from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib from .container import String from .params import save_param_dict, load_param_dict diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 63383e7710f5..7f83693292ba 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -137,3 +137,27 @@ def __from_tvm_object__(cls, obj): val = str.__new__(cls, content) val.__tvm_object__ = obj return val + + +@tvm._ffi.register_object("runtime.ShapeTuple") +class ShapeTuple(Object): + """TVM runtime ShapeTuple object. + Parameters + ---------- + shape : list[int] + The shape list used to construct the object. + """ + + def __init__(self, shape): + assert isinstance(shape, (list, tuple)), "Expect list of tuple, but received : {0}".format( + type(shape) + ) + for x in shape: + assert isinstance(x, int), "Expect int type, but received : {0}".format(type(x)) + self.__init_handle_by_constructor__(_ffi_api.ShapeTuple, *shape) + + def __len__(self): + return _ffi_api.GetShapeTupleSize(self) + + def __getitem__(self, idx): + return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index f0f33e162559..8107ab5b87d2 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -470,9 +470,6 @@ def load_module(path, fmt=""): files = [tar_temp.relpath(x) for x in tar_temp.listdir()] _cc.create_shared(path + ".so", files, cc=cc) path += ".so" - # TODO(weberlo): we should probably use a more distinctive suffix for microTVM object files - elif path.endswith(".obj"): - fmt = "micro_dev" # Redirect to the load API return _ffi_api.ModuleLoadFromFile(path, fmt) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index e19221c9f186..5a7acf0d6c30 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -268,13 +268,10 @@ def device(dev_type, dev_id=0): assert tvm.device("cuda", 0) == tvm.cuda(0) """ if isinstance(dev_type, string_types): - if "-device=micro_dev" in dev_type: - dev_type = Device.STR2MASK["micro_dev"] - else: - dev_type = dev_type.split()[0] - if dev_type not in Device.STR2MASK: - raise ValueError("Unknown device type %s" % dev_type) - dev_type = Device.STR2MASK[dev_type] + dev_type = dev_type.split()[0] + if dev_type not in Device.STR2MASK: + raise ValueError("Unknown device type %s" % dev_type) + dev_type = Device.STR2MASK[dev_type] return Device(dev_type, dev_id) @@ -510,22 +507,6 @@ def ext_dev(dev_id=0): return Device(12, dev_id) -def micro_dev(dev_id=0): - """Construct a micro device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return Device(13, dev_id) - - def hexagon(dev_id=0): """Construct a Hexagon device diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 2ca11ff30b4b..76ddbb1de697 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -151,6 +151,11 @@ def max(a, b, span): # pylint: disable=redefined-builtin return tvm.tir.Max(a, b, span) +@register +def min(a, b, span): # pylint: disable=redefined-builtin + return tvm.tir.Min(a, b, span) + + def get_axis(begin, end, iter_type, span): ana = tvm.arith.Analyzer() extent = ana.simplify(end - begin) diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py index 039eeb452ddb..c4593683da78 100644 --- a/python/tvm/script/node.py +++ b/python/tvm/script/node.py @@ -91,7 +91,7 @@ def __init__( span: Optional[Span] = None, ): def check_index(index: Union[int, PrimExpr]): - """ Check input index is non-negative integer or PrimExpr""" + """Check input index is non-negative integer or PrimExpr""" if isinstance(index, int): if index < 0: report_error("Negative index is not allowed during buffer access", span) diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 11eecc9831a4..a23401d926e9 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -104,7 +104,7 @@ def get_optional_vars(node, context): @register class Allocate(WithScopeHandler): - """ With scope handler tir.allocate(extents, dtype, scope, condition) """ + """With scope handler tir.allocate(extents, dtype, scope, condition)""" def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): @@ -149,7 +149,7 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): @register class LaunchThread(WithScopeHandler): - """ With scope handler tir.launch_thread(env_var, extent) """ + """With scope handler tir.launch_thread(env_var, extent)""" def __init__(self): def launch_thread(env_var, extent, span): @@ -175,7 +175,7 @@ def launch_thread(env_var, extent, span): @register class Realize(WithScopeHandler): - """ With scope handler tir.realize(buffer_bounds, scope, condition) """ + """With scope handler tir.realize(buffer_bounds, scope, condition)""" def __init__(self): def realize( @@ -205,7 +205,7 @@ def realize( @register class Attr(WithScopeHandler): - """ With scope handler tir.attr(attr_node, attr_key, value) """ + """With scope handler tir.attr(attr_node, attr_key, value)""" def __init__(self): def attr(attr_node, attr_key, value, span): @@ -218,7 +218,7 @@ def attr(attr_node, attr_key, value, span): @register class AssertHandler(WithScopeHandler): - """ With scope handler tir.Assert(condition, message) """ + """With scope handler tir.Assert(condition, message)""" def __init__(self): def Assert(condition, message, span): @@ -229,7 +229,7 @@ def Assert(condition, message, span): @register class Let(WithScopeHandler): - """ With scope handler tir.let(var, value) """ + """With scope handler tir.let(var, value)""" def __init__(self): def let(var, value, span): @@ -240,7 +240,7 @@ def let(var, value, span): @register class Block(WithScopeHandler): - """ With scope handler tir.block(extents, name) as iter_vars""" + """With scope handler tir.block(extents, name) as iter_vars""" def __init__(self): def block(axes=None, name_hint: str = "", span: Optional[Span] = None): @@ -359,7 +359,7 @@ def enter_scope( @register class InitBlock(WithScopeHandler): - """ With scope handler tir.init()""" + """With scope handler tir.init()""" def __init__(self): def init(span: Span = None): @@ -490,7 +490,7 @@ def create_loop( @register class Serial(ForScopeHandler): - """ For scope handler tir.serial(begin, end, annotations)""" + """For scope handler tir.serial(begin, end, annotations)""" def __init__(self): def serial( @@ -506,7 +506,7 @@ def serial( @register class Parallel(ForScopeHandler): - """ For scope handler tir.parallel(begin, end, annotations)""" + """For scope handler tir.parallel(begin, end, annotations)""" def __init__(self): def parallel( @@ -524,7 +524,7 @@ def parallel( @register class Vectorized(ForScopeHandler): - """ For scope handler tir.vectorized(begin, end, annotations)""" + """For scope handler tir.vectorized(begin, end, annotations)""" def __init__(self): def vectorized( @@ -542,7 +542,7 @@ def vectorized( @register class Unroll(ForScopeHandler): - """ For scope handler tir.unroll(begin, end, annotations)""" + """For scope handler tir.unroll(begin, end, annotations)""" def __init__(self): def unroll( @@ -560,7 +560,7 @@ def unroll( @register class ThreadBinding(ForScopeHandler): - """ For scope handler tir.thread_binding(begin, end, thread, annotations)""" + """For scope handler tir.thread_binding(begin, end, thread, annotations)""" def __init__(self): def thread_binding( @@ -606,7 +606,7 @@ def signature(self): @register class Grid(ForScopeHandler): - """ For scope handler tir.grid(extents)""" + """For scope handler tir.grid(extents)""" def __init__(self): def grid(*extents: List[PrimExpr], span: Span): diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 6aa1239e9d79..7eb938c58f96 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -478,7 +478,7 @@ def match_buffer_region( @register class VarDef(SpecialStmt): - """ Special function for defining a Var""" + """Special function for defining a Var""" def __init__(self): def var(dtype, span): @@ -493,7 +493,7 @@ def var(dtype, span): @register class EnvThread(SpecialStmt): - """ Bind a var to thread env """ + """Bind a var to thread env""" def __init__(self): def env_thread(env_name, span): diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 748b4e8910c1..be39a6f6bd25 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -296,10 +296,13 @@ def micro(model="unknown", options=None): if model not in MICRO_SUPPORTED_MODELS: raise ValueError(f"Model {model} not supported by tvm.target.micro.") opts = _merge_opts( - MICRO_SUPPORTED_MODELS[model] + ["-runtime=c", "--system-lib", f"-model={model}"], + MICRO_SUPPORTED_MODELS[model] + ["-runtime=c", f"-model={model}"], options, ) + if (not options) or (options and "--executor=aot" not in options): + opts = _merge_opts(opts, "--system-lib") + # NOTE: in the future, the default micro target will be LLVM except when # external dependencies are present. return Target(" ".join(["c"] + opts)) diff --git a/python/tvm/te/hybrid/module.py b/python/tvm/te/hybrid/module.py index beea8844f78c..af6270045b6b 100644 --- a/python/tvm/te/hybrid/module.py +++ b/python/tvm/te/hybrid/module.py @@ -85,7 +85,7 @@ def load(self, path): src = self.src_ class FindFunc(ast.NodeVisitor): - """ Find the function in module to be loaded module. """ + """Find the function in module to be loaded module.""" # pylint: disable=invalid-name def __init__(self): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index afe521a74361..eb200df0c599 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,7 +48,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError from . import schedule from . import ir_builder diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5550a9e3c74f..ef1cab1fb663 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,4 +19,4 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .state import ScheduleDebugMask, ScheduleState -from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule +from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f207fa274212..9452f5ab72ee 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,6 +19,7 @@ from typing import List, Optional, Union from tvm._ffi import register_object as _register_object +from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object from tvm.tir import Block, For, IntImm, PrimFunc, Var @@ -27,6 +28,11 @@ from .state import ScheduleState, StmtSRef +@register_error +class ScheduleError(TVMError): + """Error that happens during TensorIR scheduling.""" + + @_register_object("tir.LoopRV") class LoopRV(Object): """A random variable that refers to a loop""" @@ -57,10 +63,14 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ + ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2} + def __init__( self, func_or_mod: Union[PrimFunc, IRModule], + *, debug_mode: Union[bool, int] = False, + error_render_level: str = "detail", ): """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc @@ -71,6 +81,11 @@ def __init__( debug_mode : Union[bool, int] Do extra correctness checking after the class creation and each time scheduling primitive + error_render_level : str = "detail" + The level of error rendering. Choices: "detail", "fast", "none". + "detail": Render a detailed error message, with the TIR and error locations printed + "fast: Show a simple error message without rendering or string manipulation + "none": Do not show any error message. Note ---------- @@ -85,10 +100,17 @@ def __init__( debug_mode = 0 if not isinstance(debug_mode, int): raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") + if error_render_level not in Schedule.ERROR_RENDER_LEVEL: + raise ValueError( + 'error_render_level can be "detail", "fast", or "none", but got: ' + + f"{error_render_level}" + ) + error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level) self.__init_handle_by_constructor__( _ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member func_or_mod, debug_mode, + error_render_level, ) ########## Utilities ########## @@ -234,6 +256,121 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: """ return _ffi_api_schedule.ScheduleGetLoops(self, block) # pylint: disable=no-member + ########## Schedule: loops manipulation ########## + ########## Schedule: compute location ########## + def compute_inline(self, block: BlockRV) -> None: + """Inline a block into its consumer(s). It requires: + 1) The block is a complete non-root block, which only produces one buffer + 2) The block must not be the only leaf in the scope. + 3) The body of the block must be a BufferStore statement in the form of, + A[i, j, k, ...] = ... + where the indices of the LHS are all distinct atomic variables, + and no variables other than those indexing variables are allowed in the statement. + + Parameters + ---------- + block : BlockRV + The block to be inlined to its consumer(s) + + Examples + -------- + + Before compute-inline, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do compute-inline: + + .. code-block:: python + + sch = tir.Schedule(before_inline, debug_mode=True) + sch.compute_inline(sch.get_block("B")) + print(tvm.script.asscript(sch.mod["main"])) + + After applying compute-inline, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + """ + _ffi_api_schedule.ScheduleComputeInline(self, block) # pylint: disable=no-member + + def reverse_compute_inline(self, block: BlockRV) -> None: + """Inline a block into its only producer. It requires: + 1) The block is a complete non-root block, which only produces and consumes one buffer + 2) The block must not be the only leaf in the scope. + 3) The only producer of the block is a read-after-write producer + and a complete non-root block + 4) The body of the block must be a BufferStore statement in the form of, + B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + and no variables other than those indexing variables are allowed in the statement. + + Parameters + ---------- + block : BlockRV + The block to be inlined to its producer + + Examples + -------- + + Before reverse-compute-inline, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do reverse-compute-inline: + + .. code-block:: python + + sch = tir.Schedule(before_inline, debug_mode=True) + sch.reverse_compute_inline(sch.get_block("C")) + print(tvm.script.asscript(sch.mod["main"])) + + After applying reverse-compute-inline, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + """ + _ffi_api_schedule.ScheduleReverseComputeInline(self, block) # pylint: disable=no-member + + ########## Schedule: loop binding/annotation ########## + ########## Schedule: cache read/write ########## + ########## Schedule: reduction ########## + ########## Schedule: blockize & tensorize ########## + @_register_object("tir.ConcreteSchedule") class ConcreteSchedule(Schedule): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index be55b48da71e..26b22f99c215 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -347,6 +347,17 @@ def MakePackedAPI(num_unpacked_params=0): return _ffi_api.MakePackedAPI(num_unpacked_params) +def MakeUnpackedAPI(): + """Transform the PrimFuncs in the module to a C API compatible with internal calls. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MakeUnpackedAPI() + + def SplitHostDevice(): """Split the function into a host function and device functions. diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 15d84c20ed23..52ed7a18df81 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -43,14 +43,14 @@ def get_arch_version(target_mattr): def is_dotprod_available(): - """ Checks whether the hardware has support for udot/sdot instructions. """ + """Checks whether the hardware has support for udot/sdot instructions.""" target = tvm.target.Target.current(allow_none=False) arch_version = get_arch_version(target.mattr) return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr) def is_mmla_available(): - """ Checks whether the hardware has support for ummla/smmla instructions. """ + """Checks whether the hardware has support for ummla/smmla instructions.""" target = tvm.target.Target.current(allow_none=False) arch_version = get_arch_version(target.mattr) return arch_version >= 8.6 or ( @@ -59,7 +59,7 @@ def is_mmla_available(): def is_aarch64_arm(): - """ Checks whether we are compiling for an AArch64 target. """ + """Checks whether we are compiling for an AArch64 target.""" target = tvm.target.Target.current(allow_none=False) return "aarch64" in target.attrs.get("mtriple", "") diff --git a/python/tvm/topi/arm_cpu/bitserial_conv2d.py b/python/tvm/topi/arm_cpu/bitserial_conv2d.py index 6406861885c3..def9b8345cd8 100644 --- a/python/tvm/topi/arm_cpu/bitserial_conv2d.py +++ b/python/tvm/topi/arm_cpu/bitserial_conv2d.py @@ -55,7 +55,7 @@ def bitserial_conv2d_nhwc( out_dtype, unipolar, ): - """ Compute convolution with pack on spatial axes. """ + """Compute convolution with pack on spatial axes.""" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" assert pack_dtype == "uint8", "only support packing into uint8 bits" assert out_dtype == "int16", "only support output type of int16" diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 7dbbf9d3d447..b3af36740551 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -381,7 +381,7 @@ def _callback(op): def _conv2d_arm_cpu_winograd_nnpack( cfg, data, kernel, strides, padding, dilation, out_dtype, convolution_algorithm ): - """ TOPI compute callback. Use winograd NNPACK template """ + """TOPI compute callback. Use winograd NNPACK template""" N, CI, IH, IW = get_const_tuple(data.shape) if isinstance(dilation, int): diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 85c03997a98d..8e416be8daa2 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -33,7 +33,7 @@ def configure_knobs(cfg, M, K): - """ Configure auto-tuning knobs for the interleaved strategy """ + """Configure auto-tuning knobs for the interleaved strategy""" x, y = cfg.axis(M // 4), cfg.axis(K // 16) cfg.define_reorder("reorder_gemm", [x, y], policy="candidate", candidate=[[x, y], [y, x]]) @@ -280,7 +280,7 @@ def compute_conv2d_gemm_without_weight_transform( def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): - """ Schedule the conv2d_gemm interleaved strategy """ + """Schedule the conv2d_gemm interleaved strategy""" C = out.op.input_tensors[0] C_interleaved = C.op.input_tensors[0] A_interleaved = C_interleaved.op.input_tensors[0] @@ -372,7 +372,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): def schedule_conv2d_gemm_native(cfg, s, out, final_out): - """ Schedule the conv2d_gemm hybrid strategy """ + """Schedule the conv2d_gemm hybrid strategy""" C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index fc7e4036341a..bf4c03a6e5ed 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -196,7 +196,7 @@ def _callback(op): def compute_conv2d_NHWC_quantized_interleaved( cfg, data, kernel, strides, padding, dilation, out_dtype ): - """ Interface for interleaved compute_conv2d_NHWC_quantized_interleaved""" + """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved""" return _compute_conv2d_NHWC_quantized( cfg, data, kernel, strides, padding, dilation, out_dtype, True ) @@ -206,7 +206,7 @@ def compute_conv2d_NHWC_quantized_interleaved( def compute_conv2d_NHWC_quantized_interleaved_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): - """ Interface for interleaved compute_conv2d_NHWC_quantized_interleaved_without_transform""" + """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved_without_transform""" return _compute_conv2d_NHWC_quantized_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels, True ) @@ -214,7 +214,7 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved.arm_cpu") def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs): - """ Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" + """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" return _schedule_conv2d_NHWC_quantized(cfg, outs, True) @@ -222,7 +222,7 @@ def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs): # The weights are interleaved and transposed @autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu") def compute_conv2d_NHWC_quantized_native(cfg, data, kernel, strides, padding, dilation, out_dtype): - """ Interface for native compute_conv2d_NHWC_quantized""" + """Interface for native compute_conv2d_NHWC_quantized""" return _compute_conv2d_NHWC_quantized( cfg, data, kernel, strides, padding, dilation, out_dtype, False ) @@ -232,7 +232,7 @@ def compute_conv2d_NHWC_quantized_native(cfg, data, kernel, strides, padding, di def compute_conv2d_NHWC_quantized_native_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): - """ Interface for compute_conv2d_NHWC_quantized_native_without_transform""" + """Interface for compute_conv2d_NHWC_quantized_native_without_transform""" return _compute_conv2d_NHWC_quantized_without_transform( cfg, data, @@ -249,5 +249,5 @@ def compute_conv2d_NHWC_quantized_native_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native.arm_cpu") def schedule_conv2d_NHWC_quantized_native(cfg, outs): - """ Interface for native schedule_conv2d_NHWC_quantized""" + """Interface for native schedule_conv2d_NHWC_quantized""" return _schedule_conv2d_NHWC_quantized(cfg, outs, False) diff --git a/python/tvm/topi/bifrost/dense.py b/python/tvm/topi/bifrost/dense.py index 9ab8b4ebea62..7e827813ed66 100644 --- a/python/tvm/topi/bifrost/dense.py +++ b/python/tvm/topi/bifrost/dense.py @@ -103,7 +103,7 @@ def _callback(op): def fuse_and_bind(s, tensor, axis=None, num_thread=None): - """ fuse all the axis and bind to GPU threads """ + """fuse all the axis and bind to GPU threads""" axis = axis or s[tensor].op.axis fused = s[tensor].fuse(*axis) bx, tx = s[tensor].split(fused, num_thread) diff --git a/python/tvm/topi/bifrost/depthwise_conv2d.py b/python/tvm/topi/bifrost/depthwise_conv2d.py index 625c274213ad..801acd676aa6 100644 --- a/python/tvm/topi/bifrost/depthwise_conv2d.py +++ b/python/tvm/topi/bifrost/depthwise_conv2d.py @@ -52,7 +52,7 @@ def _schedule(pad_data, kernel, conv): output = conv def tile_and_bind3d(tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): - """ tile and bind 3d """ + """tile and bind 3d""" y_factor = y_factor or z_factor x_factor = x_factor or y_factor zo, zi = s[tensor].split(z, z_factor) diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py index b3d8397791fe..be9218431c85 100644 --- a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py @@ -65,7 +65,7 @@ def unpack_HWNCnc_to_hwnc(packed_out, out_dtype): def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype="int32"): - """"Compute conv2d with tensorcore for HWNC layout with int8/int4""" + """ "Compute conv2d with tensorcore for HWNC layout with int8/int4""" assert data.dtype in ("int4", "uint4", "int8", "uint8") assert kernel.dtype in ("int4", "uint4", "int8", "uint8") packed_out = hwnc_tensorcore_cuda(data, kernel, strides, padding, dilation, out_dtype) diff --git a/python/tvm/topi/cuda/group_conv2d_nchw.py b/python/tvm/topi/cuda/group_conv2d_nchw.py index 2af011700235..d75cfffc1af8 100644 --- a/python/tvm/topi/cuda/group_conv2d_nchw.py +++ b/python/tvm/topi/cuda/group_conv2d_nchw.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name +# pylint: disable=no-value-for-parameter """The template for cuda group_conv2d_nchw""" import tvm from tvm import te @@ -23,11 +24,28 @@ from .injective import schedule_injective_from_existing from .tensor_intrin import dp4a from ..nn.pad import pad +from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.utils import get_pad_tuple from ..utils import traverse_inline, get_const_tuple, get_const_int from .. import nn +def group_conv2d_nchw_int8(data, kernel, strides, padding, dilation, groups, out_dtype="float32"): + """Compute group_conv2d internally using group_conv2d_nchwc layout for int8 dtype""" + assert data.dtype in ("int8", "uint8") + assert kernel.dtype in ("int8", "uint8") + assert data.dtype == kernel.dtype + packed_out = group_conv2d_NCHWc_int8( + data, kernel, strides, padding, dilation, groups, out_dtype + ) + return unpack_NCHWc_to_nchw(packed_out, out_dtype) + + +def schedule_group_conv2d_nchw_int8(outs): + """Create schedule for tensors""" + return schedule_group_conv2d_NCHWc_int8(outs) + + @autotvm.register_topi_compute("group_conv2d_nchw.cuda") def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups, out_dtype="float32"): return nn.group_conv2d_nchw(data, kernel, stride, padding, dilation, groups, out_dtype) @@ -422,7 +440,13 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output): oc_chunk = get_const_int(output.shape[1]) # tile and bind spatial axes - n, f, y, x, c = s[output].op.axis + if len(s[output].op.axis) == 5: + n, f, y, x, c = s[output].op.axis + else: + # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks + # are created from scratch, therefore the real auto-tuning will still happen on 5D output. + n, f, y, x = s[output].op.axis + cfg.define_split("tile_n", n, num_outputs=4) cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2) cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 9a3b86d72b18..e402c5888978 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -32,6 +32,7 @@ calculate_overlap, binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -988,8 +989,74 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_size * num_class + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + zero = cast(0, "int64") + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + + with ib.if_scope(idx < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + idx + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") + collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + idx + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -1012,16 +1079,30 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below + Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is `indices` of size + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected - boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ batch, num_class, num_boxes = scores.shape @@ -1029,7 +1110,7 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -1037,14 +1118,30 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=(output_format == "tensorflow"), ) + if output_format == "onnx": + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" + num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, ) - return [selected_indices, num_total_detections] + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index ceab71640533..b9d02d9c81d8 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -37,7 +37,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): all_reduce = False num_thread = 32 target = tvm.target.Target.current() - if target and target.kind.name == "opencl": + if target and (target.kind.name == "opencl" or target.kind.name == "metal"): # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why num_thread = 16 diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 6dbaf02191c8..0d19a92f2058 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -231,7 +231,7 @@ def ir(data, data_ex_scan, reduction): data[tid * scan_axis_size + scan_axis_size - 1], ) with ib.else_scope(): - reduction[tid] = 0 + reduction[tid] = cast(0, reduction.dtype) return ib.get() diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py index 99fbdd0367db..b743aefc50d5 100644 --- a/python/tvm/topi/cuda/softmax.py +++ b/python/tvm/topi/cuda/softmax.py @@ -47,8 +47,15 @@ def schedule_softmax(outs): expsum = softmax.op.input_tensors[1] exp = softmax.op.input_tensors[0] max_elem = s[exp].op.input_tensors[1] + delta = None + elif op_tag == "fast_softmax_output": + expsum = softmax.op.input_tensors[1] + exp = softmax.op.input_tensors[0] + delta = s[exp].op.input_tensors[0] + max_elem = s[delta].op.input_tensors[1] elif op_tag == "log_softmax_output": exp = None + delta = None max_elem = softmax.op.input_tensors[1] expsum = softmax.op.input_tensors[2] else: @@ -73,6 +80,8 @@ def sched_warp_softmax(): if len(softmax.shape) > 2: ops = [max_elem.op, expsum.op, softmax.op] + if delta is not None: + ops.append(delta.op) if exp is not None: ops.append(exp.op) @@ -99,7 +108,10 @@ def sched_warp_softmax(): s[expsum].compute_at(s[softmax], xo) # (2) exp - if exp is not None: + if delta is not None: + s[exp].compute_inline() + s[delta].compute_inline() + elif exp is not None: xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread) _, xii = s[exp].split(xi, factor=4) s[exp].vectorize(xii) @@ -112,7 +124,7 @@ def sched_warp_softmax(): k = max_elem.op.reduce_axis[0] ko, _ = s[max_elem].split(k, nparts=num_thread) s[max_elem].bind(ko, thread_x) - if exp is not None: + if exp is not None and delta is None: s[max_elem].compute_at(s[exp], xo) else: s[max_elem].bind(ko, thread_x) @@ -123,7 +135,10 @@ def sched_warp_softmax(): block_x = te.thread_axis("blockIdx.x") thread_x = te.thread_axis((0, num_thread), "threadIdx.x") - if exp is not None: + if delta is not None: + s[exp].compute_inline() + s[delta].compute_inline() + elif exp is not None: s[exp].bind(exp.op.axis[0], block_x) s[max_elem].bind(max_elem.op.axis[0], block_x) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 4476648e0aa4..7a796fa42696 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -48,6 +48,7 @@ def sparse_reshape( Examples -------- .. code-block:: python + sparse_indices = [[0, 0, 0], [0, 0, 1], [0, 1, 0], diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index 2bca3c447c4c..8f78cc5fb924 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -119,7 +119,7 @@ def _calc_num_unique(inc_scan): def _calc_unique_ir( - data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts + data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts ): """Low level IR to calculate unique elements, inverse indices, and counts (optional) of unique elements of 1-D array. @@ -143,7 +143,7 @@ def _calc_unique_ir( unique_elements : Buffer A buffer that stores the unique elements. - indices : Buffer + inverse_indices : Buffer A buffer that stores the the index of each input data element in the unique element array. counts (optional) : Buffer @@ -154,7 +154,7 @@ def _calc_unique_ir( argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) unique_elements_ptr = ib.buffer_ptr(unique_elements) - indices_ptr = ib.buffer_ptr(indices) + inverse_indices_ptr = ib.buffer_ptr(inverse_indices) index_converter_ptr = None if isinstance(index_converter, tir.Buffer): @@ -163,7 +163,7 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] - unique_seq_indices_ptr = ib.buffer_ptr(indices) + unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices) batch_size = data.shape[0] max_threads = _get_max_threads(batch_size) @@ -218,7 +218,7 @@ def _calc_unique_ir( if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[tid]] ) - indices_ptr[data_idx] = unique_idx + inverse_indices_ptr[data_idx] = unique_idx with ib.if_scope(tid == 0): unique_elements_ptr[unique_idx] = data_ptr[data_idx] with ib.else_scope(): @@ -293,11 +293,20 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : tvm.te.Tensor - A 1-D tensor containing the unique elements of the input data tensor. + unique : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. The same size as + the input data. If there are less unique elements than input data, the end of the tensor + is padded with zeros. indices : tvm.te.Tensor - A 1-D tensor containing the index of each data element in the output tensor. + A 1-D tensor. The same size as output. For each entry in output, it contains + the index of its first occurence in the input data. The end of the tensor is padded + with the length of the input data. + + inverse_indices : tvm.te.Tensor + A 1-D tensor. For each entry in data, it contains the index of that data element in the + unique array. (Note that inverse_indices is very similar to indices if output is not + sorted) num_unique : tvm.te.Tensor A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -308,21 +317,25 @@ def unique(data, is_sorted=True, return_counts=False): Examples -------- .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] - counts = [2, 2, 1, 1, 2, ?, ?, ?] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) - output = [1, 2, 3, 4, 5, ?, ?, ?] - indices = [3, 4, 0, 1, 2, 2, 3, 4] - num_unique = [5] + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [2, 3, 4, 0, 1, ?, ?, ?] + inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] """ sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") @@ -355,6 +368,20 @@ def unique(data, is_sorted=True, return_counts=False): out_buffers = [unique_elements_buf, inverse_indices_buf] out_dtypes = [data.dtype, "int32"] # prepare inputs and fcompute + # calculate first occurence + first_occurence_buf = tir.decl_buffer( + data.shape, "int32", "first_occurence_buf", data_alignment=8 + ) + first_occurence = te.extern( + [data.shape], + [argsorted_indices, inc_scan], + lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[argsorted_indices_buf, inc_scan_buf], + out_buffers=[first_occurence_buf], + name="_calc_first_occurence", + tag="_calc_first_occurence_gpu", + ) if is_sorted: in_data = [data, argsorted_indices, inc_scan] in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf] @@ -362,22 +389,8 @@ def unique(data, is_sorted=True, return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + indices = first_occurence else: - # calculate the index converter if the unique elements should not be sorted - # calculate first occurence - first_occurence_buf = tir.decl_buffer( - data.shape, "int32", "first_occurence_buf", data_alignment=8 - ) - first_occurence = te.extern( - [data.shape], - [argsorted_indices, inc_scan], - lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), - dtype=["int32"], - in_buffers=[argsorted_indices_buf, inc_scan_buf], - out_buffers=[first_occurence_buf], - name="_calc_first_occurence", - tag="_calc_first_occurence_gpu", - ) # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") @@ -390,6 +403,7 @@ def unique(data, is_sorted=True, return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + indices = sort(first_occurence) outs = te.extern( out_data_shape, in_data, @@ -401,5 +415,5 @@ def unique(data, is_sorted=True, return_counts=False): tag="_calc_unique_gpu", ) if return_counts: - return [outs[0], outs[1], num_unique_elements, outs[2]] - return [*outs, num_unique_elements] + return [outs[0], indices, outs[1], num_unique_elements, outs[2]] + return [outs[0], indices, outs[1], num_unique_elements] diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 88983ab89f76..5208aeccd413 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -39,7 +39,9 @@ def traverse(op): traverse(tensor.op) scheduled_ops.append(op) - traverse(outs[0].op) + for o in outs: + traverse(o.op) + return s diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 866887706862..04d649037fef 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -563,6 +563,23 @@ def schedule_softmax(outs): return _default_schedule(outs, False) +def schedule_fast_softmax(outs): + """Schedule for fast_softmax + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of fast_softmax + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_dense(outs): """Schedule for dense diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index f0d564581d95..42d0455665a1 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -24,7 +24,7 @@ def get_2d_indices(indices, layout="NCHW"): - """ Get 2d indices """ + """Get 2d indices""" (cc, inum, ic) = (0, 0, 0) if layout == "NHWC": n, y, x, c = indices @@ -43,7 +43,7 @@ def get_2d_indices(indices, layout="NCHW"): def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic): - """ Get 2d pixel """ + """Get 2d pixel""" if boxes is None: y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) @@ -62,7 +62,7 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, def get_iny_inx( y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode ): - """ Infer input x,y from output x,y with various coordinate transformation methods """ + """Infer input x,y from output x,y with various coordinate transformation methods""" scale_y = te.div(image_height.astype("float"), target_height.astype("float")) scale_x = te.div(image_width.astype("float"), target_width.astype("float")) if coordinate_transformation_mode == "half_pixel": diff --git a/python/tvm/topi/intel_graphics/conv2d.py b/python/tvm/topi/intel_graphics/conv2d.py index bdbde91918dd..b276bcae92b1 100644 --- a/python/tvm/topi/intel_graphics/conv2d.py +++ b/python/tvm/topi/intel_graphics/conv2d.py @@ -125,7 +125,7 @@ def _create_schedule_template(cfg, dshape, kshape, strides, padding, dilation): ##### SCHEDULE UTILITIES ##### def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): - """ tile and bind 3d """ + """tile and bind 3d""" y_factor = y_factor or z_factor x_factor = x_factor or y_factor zo, zi = s[tensor].split(z, z_factor) @@ -580,7 +580,7 @@ def _schedule_cl_spatialpack(s, op): temp = s[conv].op.input_tensors[0] kernel_vec = s[conv].op.input_tensors[1] kernel = s[kernel_vec].op.input_tensors[0] - temp_W = s.cache_read(temp, "warp", [conv]) + temp_W = s.cache_read(temp, "shared", [conv]) conv_L = s.cache_write(conv, "local") kernel_L = s.cache_read(kernel_vec, "local", [conv_L]) diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index 62be92cf1927..52fe011a70e9 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -607,7 +607,7 @@ def conv2d_winograd_nhwc_mali( ##### SCHECULE UTILITIES ##### def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None): - """ tile and bind to GPU threads """ + """tile and bind to GPU threads""" x_factor = x_factor or y_factor yo, xo, yi, xi = s[tensor].tile(y, x, y_factor, x_factor) s[tensor].bind(xo, te.thread_axis("blockIdx.x")) @@ -618,7 +618,7 @@ def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None): def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): - """ tile and bind 3d """ + """tile and bind 3d""" y_factor = y_factor or z_factor x_factor = x_factor or y_factor zo, zi = s[tensor].split(z, z_factor) diff --git a/python/tvm/topi/mali/dense.py b/python/tvm/topi/mali/dense.py index 53f76219bacd..a8ca66b09cd5 100644 --- a/python/tvm/topi/mali/dense.py +++ b/python/tvm/topi/mali/dense.py @@ -103,7 +103,7 @@ def _callback(op): def fuse_and_bind(s, tensor, axis=None, num_thread=None): - """ fuse all the axis and bind to GPU threads """ + """fuse all the axis and bind to GPU threads""" # TODO(@comaniac): figure out where this function is used. axis = axis or s[tensor].op.axis fused = s[tensor].fuse(*axis) diff --git a/python/tvm/topi/mali/depthwise_conv2d.py b/python/tvm/topi/mali/depthwise_conv2d.py index 55fcb1de9c4a..b292f694b995 100644 --- a/python/tvm/topi/mali/depthwise_conv2d.py +++ b/python/tvm/topi/mali/depthwise_conv2d.py @@ -132,7 +132,7 @@ def _callback(op): def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): - """ tile and bind 3d """ + """tile and bind 3d""" y_factor = y_factor or z_factor x_factor = x_factor or y_factor zo, zi = s[tensor].split(z, z_factor) diff --git a/python/tvm/topi/nn/bitserial_conv2d.py b/python/tvm/topi/nn/bitserial_conv2d.py index 78d05d027659..87acd4e3602c 100644 --- a/python/tvm/topi/nn/bitserial_conv2d.py +++ b/python/tvm/topi/nn/bitserial_conv2d.py @@ -40,10 +40,10 @@ def bitserial_conv2d_nchw( Parameters ---------- - input : tvm.te.Tensor + data : tvm.te.Tensor 4-D with shape [batch, in_channel, in_height, in_width] - filter : tvm.te.Tensor + kernel : tvm.te.Tensor 4-D with shape [num_filter, in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints @@ -74,10 +74,10 @@ def bitserial_conv2d_nchw( """ assert isinstance(stride, int) or len(stride) == 2 Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype) - if len(filter.shape) == 4: - Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + if len(kernel.shape) == 4: + Filter_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) else: - Filter_q = filter + Filter_q = kernel batch, in_channel, activation_bits, in_height, in_width = Input_q.shape num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape @@ -163,10 +163,10 @@ def bitserial_conv2d_nhwc( Parameters ---------- - input : tvm.te.Tensor + data : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] - filter : tvm.te.Tensor + kernel : tvm.te.Tensor 4-D with shape [filter_height, filter_width, in_channel, num_filter] stride : int or a list/tuple of two ints diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 80f87f86736c..130eb4b69844 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -159,7 +159,7 @@ def conv2d_infer_layout(workload, cfg): def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"): - """ Get the workload structure. """ + """Get the workload structure.""" if data_layout == "NCHW": _, CI, IH, IW = get_const_tuple(data.shape) elif data_layout == "NHWC": diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index 052ab8b88d1c..a3639b57e7e0 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -51,7 +51,7 @@ def _get_workload(data, kernel, stride, padding, dilation, out_dtype): - """ Get the workload structure. """ + """Get the workload structure.""" _, in_channel, height, width = [x.value for x in data.shape] channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape] out_channel = channel * channel_multiplier diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index 1be4c86e63c0..8b6bb114b181 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -216,7 +216,7 @@ def threefry_generate(gen, out_shape): not be reused in another function, otherwise random numbers will be repeated. out_shape : Sequence[int] - Output shape of the random numbers. Product of all dimensions must be a multiple of 4. + Output shape of the random numbers. Returns ------- @@ -229,9 +229,6 @@ def threefry_generate(gen, out_shape): out_len = tir.const(1) for s in out_shape: out_len *= s - assert ( - out_len.value % 4 == 0 - ), f"Threefry can only generate arrays who's size is a multiple of 4 ({out_len} was provided)." assert ( out_len.value <= 2 ** 64 - 1 ), f"Can only generate up to 2^64 random numbers, but {out_len} were requested." @@ -296,7 +293,14 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): _shift_right(irb, gen[8], gen[9], tmp, 8, tmp, 9) # Compute random values - _threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4) + if out_len.value >= 4: + _threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4) + if out_len.value % 4 != 0: + remaining = irb.allocate(gen.dtype, 4, name="remaining", scope="global") + tmp[7] = tmp[7] + tir.Cast(gen.dtype, out_len // 4 * 4) # increment counter + _threefry(irb, tmp, 0, tmp, 4, remaining, 0, 1) + with irb.for_range(0, out_len % 4, dtype="uint64", name="i") as i: + out_array[out_len // 4 * 4 + i] = remaining[i] # Update generator state out_gen[0] = tmp[0] # key stays the same @@ -306,7 +310,13 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): out_gen[4] = tmp[4] # path stays the same out_gen[5] = tmp[5] out_gen[6] = tir.const(0, dtype=gen.dtype) # unused, leave it as 0 - out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) # increment counter + if out_len.value % 4 != 0: + # increment counter for the remaining + # as we will generate 4 random numbers for the remaining, increase 4 here. + # the main increment was done before the second _threefry. + out_gen[7] = tmp[7] + tir.Cast(gen.dtype, 4) + else: + out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) # increment counter out_gen[8] = tmp[8] # path unchanged, so no update here out_gen[9] = tmp[9] @@ -490,7 +500,7 @@ def uniform(gen, low, high, out_shape, out_dtype): less than high. out_shape : Sequence[int] - Output shape of the random numbers. Product of all dimensions must be a multiple of 4. + Output shape of the random numbers. out_dtype : str The output dtype. diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index d7b008c4c33f..0fe29f315b43 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate +from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr from ..te import extern, hybrid @@ -206,12 +206,16 @@ def _verify_scatter_nd_inputs(data, indices, updates): f"the length of the shape of the output ({len(shape)})." ) for i in range(len(indices.shape) - 1): + if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var): + continue assert indices.shape[i + 1] == updates.shape[i], ( f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"updates[{i}] ({updates.shape[i]})." ) for i in range(mdim, len(data.shape)): data_ind = i - mdim + len(indices.shape) - 1 + if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var): + continue assert updates.shape[data_ind] == data.shape[i], ( f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension " f"of out_shape[{i}] ({data.shape[i]})." diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 5535477e17c8..f2c0a2928b93 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -45,6 +45,7 @@ def sparse_reshape( Examples -------- .. code-block:: python + sparse_indices = [[0, 0, 0], [0, 0, 1], [0, 1, 0], diff --git a/python/tvm/topi/testing/adaptive_pool_python.py b/python/tvm/topi/testing/adaptive_pool_python.py index dd8fadd71f14..9a61e52a2826 100644 --- a/python/tvm/topi/testing/adaptive_pool_python.py +++ b/python/tvm/topi/testing/adaptive_pool_python.py @@ -73,7 +73,7 @@ def _pool3d(in_size, out_size, np_data, np_op): def adaptive_pool_channel_first(np_data, out_size, pool_op, np_op): - """ The reference function for adaptive pool, channel first layout """ + """The reference function for adaptive pool, channel first layout""" ishape = np_data.shape n, c = ishape[:2] oshape = (n, c) + out_size @@ -87,7 +87,7 @@ def adaptive_pool_channel_first(np_data, out_size, pool_op, np_op): def adaptive_pool_channel_last(np_data, out_size, pool_op, np_op): - """ The reference function for adaptive pool, channel last layout """ + """The reference function for adaptive pool, channel last layout""" ishape = np_data.shape n, c = ishape[0], ishape[-1] oshape = (n,) + out_size + (c,) @@ -108,7 +108,7 @@ def adaptive_pool_channel_last(np_data, out_size, pool_op, np_op): def adaptive_pool(np_data, out_size, pool_type, layout): - """ The reference function for adaptive pool, for 2d and 3d """ + """The reference function for adaptive pool, for 2d and 3d""" if isinstance(out_size, int): out_size = (out_size,) if len(out_size) == 1: diff --git a/python/tvm/topi/testing/bilinear_resize_python.py b/python/tvm/topi/testing/bilinear_resize_python.py index 844546e0643f..b1fb8b0b4845 100644 --- a/python/tvm/topi/testing/bilinear_resize_python.py +++ b/python/tvm/topi/testing/bilinear_resize_python.py @@ -22,7 +22,7 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): - """ Bilinear scaling using python""" + """Bilinear scaling using python""" (new_h, new_w) = out_size (ib, ic) = (1, 1) diff --git a/python/tvm/topi/testing/strided_slice_python.py b/python/tvm/topi/testing/strided_slice_python.py index 30466c785778..3843d0996777 100644 --- a/python/tvm/topi/testing/strided_slice_python.py +++ b/python/tvm/topi/testing/strided_slice_python.py @@ -17,7 +17,7 @@ """strided_slice/set in python""" -def strided_slice_python(data, begin, end, strides, slice_mode="end"): +def strided_slice_python(data, begin, end, strides, slice_mode="end", axes=None): """Python version of strided slice operator. Parameters @@ -41,6 +41,8 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"): the sizeof a slice starting at the location specified by begin. If end[i] is -1, all remaining elements in that dimension are included in the slice. + axes : list, optional + Axes along which slicing is applied Returns ------- @@ -48,6 +50,22 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"): The sliced result. """ strides = [] if strides is None else strides + if axes is not None: + rank = len(data.shape) + new_begin = [0] * rank + new_end = [data.shape[i] for i in range(rank)] + new_strides = [1] * rank + + for i, axis in enumerate(axes): + new_begin[axis] = begin[i] + new_end[axis] = end[i] + if len(strides) > i: + new_strides[axis] = strides[i] + + begin = new_begin + end = new_end + strides = new_strides + slices = [] for i in range(len(data.shape)): new_stride = None @@ -66,6 +84,7 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"): new_end = end[i] slices.append(slice(new_begin, new_end, new_stride)) + return data[tuple(slices)] diff --git a/python/tvm/topi/testing/trilinear_resize3d_python.py b/python/tvm/topi/testing/trilinear_resize3d_python.py index de1e2307737f..d603e987d5ef 100644 --- a/python/tvm/topi/testing/trilinear_resize3d_python.py +++ b/python/tvm/topi/testing/trilinear_resize3d_python.py @@ -23,7 +23,7 @@ def trilinear_resize3d_python( data_in, out_size, layout, coordinate_transformation_mode="align_corners" ): - """ Trilinear 3d scaling using python""" + """Trilinear 3d scaling using python""" (new_d, new_h, new_w) = out_size if layout == "NDHWC": diff --git a/python/tvm/topi/testing/upsampling_python.py b/python/tvm/topi/testing/upsampling_python.py index 7f48aa47b8d1..dd187c4d8cff 100644 --- a/python/tvm/topi/testing/upsampling_python.py +++ b/python/tvm/topi/testing/upsampling_python.py @@ -22,7 +22,7 @@ def upsample_nearest(arr, scale): - """ Populate the array by scale factor""" + """Populate the array by scale factor""" h, w = arr.shape out_h = int(round(h * scale[0])) out_w = int(round(w * scale[1])) @@ -36,7 +36,7 @@ def upsample_nearest(arr, scale): def upsampling_python(data, scale, layout="NCHW"): - """ Python version of scaling using nearest neighbour """ + """Python version of scaling using nearest neighbour""" ishape = data.shape if layout == "NCHW": @@ -87,7 +87,7 @@ def upsampling_python(data, scale, layout="NCHW"): def upsample3d_nearest(arr, scale): - """ Populate the array by scale factor""" + """Populate the array by scale factor""" d, h, w = arr.shape out_d = int(round(d * scale[0])) out_h = int(round(h * scale[1])) @@ -104,7 +104,7 @@ def upsample3d_nearest(arr, scale): def upsampling3d_python(data, scale, layout="NCDHW"): - """ Python version of 3D scaling using nearest neighbour """ + """Python version of 3D scaling using nearest neighbour""" ishape = data.shape if layout == "NCDHW": diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index df30ff775f60..b4d0167be2b1 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0): return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis) -def strided_slice(a, begin, end, strides=None, slice_mode="end"): +def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): """Slice of an array. Parameters @@ -189,6 +189,10 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): in that case, the input tensor will be reversed in that particular axis. + axes : list of int, optional + Axes along which slicing is applied. When it is specified, begin, end + strides, and axes need to a list of integers of the same length. + slice_mode : str, optional The slice mode [end, size]. end - The ending indices for the slice [default]. @@ -205,6 +209,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): or isinstance(end, tvm.te.Tensor) or isinstance(strides, tvm.te.Tensor) ): + assert axes is None, "axes argument is not supported by dynamic strided slice yet." if not isinstance(begin, tvm.te.Tensor): begin = const_vector(begin) if not isinstance(end, tvm.te.Tensor): @@ -216,7 +221,9 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): return cpp.dynamic_strided_slice(a, begin, end, strides) if strides is None: strides = [] - return cpp.strided_slice(a, begin, end, strides, slice_mode) + if axes is None: + axes = [] + return cpp.strided_slice(a, begin, end, strides, axes, slice_mode) @tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set") diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index e7256551d7b6..5aeadc541e29 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -93,7 +93,7 @@ def _calc_num_unique(inc_scan): def _calc_unique_ir( - data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts + data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts ): """Low level IR to calculate unique elements, inverse indices, and counts (optional) of unique elements of 1-D array. @@ -117,7 +117,7 @@ def _calc_unique_ir( unique_elements : Buffer A buffer that stores the unique elements. - indices : Buffer + inverse_indices : Buffer A buffer that stores the the index of each input data element in the unique element array. counts (optional) : Buffer @@ -128,7 +128,7 @@ def _calc_unique_ir( argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) unique_elements_ptr = ib.buffer_ptr(unique_elements) - indices_ptr = ib.buffer_ptr(indices) + inverse_indices_ptr = ib.buffer_ptr(inverse_indices) index_converter_ptr = None if isinstance(index_converter, tir.Buffer): @@ -137,7 +137,7 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] - unique_seq_indices_ptr = ib.buffer_ptr(indices) + unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices) data_length = data.shape[0] @@ -167,7 +167,7 @@ def _calc_unique_ir( unique_idx = ( inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] ) - indices_ptr[data_idx] = unique_idx + inverse_indices_ptr[data_idx] = unique_idx with ib.if_scope(i == 0): unique_elements_ptr[unique_idx] = data_ptr[data_idx] with ib.else_scope(): @@ -219,11 +219,20 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : tvm.te.Tensor - A 1-D tensor containing the unique elements of the input data tensor. + unique : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. The same size as + the input data. If there are less unique elements than input data, the end of the tensor + is padded with zeros. indices : tvm.te.Tensor - A 1-D tensor containing the index of each data element in the output tensor. + A 1-D tensor. The same size as output. For each entry in output, it contains + the index of its first occurence in the input data. The end of the tensor is padded + with the length of the input data. + + inverse_indices : tvm.te.Tensor + A 1-D tensor. For each entry in data, it contains the index of that data element in + the unique array. (Note that inverse_indices is very similar to indices if output is not + sorted.) num_unique : tvm.te.Tensor A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -234,21 +243,25 @@ def unique(data, is_sorted=True, return_counts=False): Examples -------- .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] - counts = [2, 2, 1, 1, 2, ?, ?, ?] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) - output = [1, 2, 3, 4, 5, ?, ?, ?] - indices = [3, 4, 0, 1, 2, 2, 3, 4] - num_unique = [5] + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [2, 3, 4, 0, 1, ?, ?, ?] + inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] """ sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") @@ -266,16 +279,17 @@ def unique(data, is_sorted=True, return_counts=False): out_data_shape = [data.shape] * 2 out_dtypes = [data.dtype, "int32"] # prepare inputs and fcompute + + first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) if is_sorted: in_data = [data, argsorted_indices, inc_scan] if return_counts: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + + indices = first_occurence else: - # calculate the index converter if the unique elements should not be sorted - # calculate first occurence - first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") @@ -284,6 +298,10 @@ def unique(data, is_sorted=True, return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + # First occurence is in order of sorted unique output, if we sort the first_occurence array + # we get the correct result + indices = sort(first_occurence) + outs = te.extern( out_data_shape, in_data, @@ -293,5 +311,5 @@ def unique(data, is_sorted=True, return_counts=False): tag="_calc_unique_cpu", ) if return_counts: - return [outs[0], outs[1], num_unique_elements, outs[2]] - return [*outs, num_unique_elements] + return [outs[0], indices, outs[1], num_unique_elements, outs[2]] + return [outs[0], indices, outs[1], num_unique_elements] diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 2e8528c5e76c..3a056cfb4326 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -495,5 +495,5 @@ def ceil_div(a, b): def swap(arr, axis): - """ swap arr[axis] and arr[-1] """ + """swap arr[axis] and arr[-1]""" return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 744c5ef7feda..7a51946a279a 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -22,14 +22,15 @@ from tvm.te import hybrid from tvm.tir import if_then_else -from ..sort import sort, argsort +from ..sort import argsort from ..math import cast -from ..transform import reshape +from ..transform import reshape, gather from .. import reduction from ..scan import cumsum from .nms_util import ( binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -727,8 +728,62 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + zero = cast(0, "int64") + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + with ib.if_scope(j < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -751,25 +806,40 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. + Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is `indices` of size + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected - boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ batch, num_class, num_boxes = scores.shape scores = reshape(scores, (batch * num_class, num_boxes)) - sorted_scores = sort(scores, axis=1, is_ascend=False) sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_scores = gather(scores, 1, sorted_indices) + valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -777,14 +847,29 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=(output_format == "tensorflow"), ) - row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + if output_format == "onnx": + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) + row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) + num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) + + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, ) - return [selected_indices, num_total_detections] + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 1147b1687783..d12592fd111a 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -106,28 +106,63 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of first, in descending of scores, followed by boxes from batch 0, class 1 etc. """ batch_class, num_boxes = selected_indices.shape - - selected_indices_buf = tvm.tir.decl_buffer( - selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 - ) - num_detections_buf = tvm.tir.decl_buffer( - num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 - ) - row_offsets_buf = tvm.tir.decl_buffer( - row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 - ) - return te.extern( [(batch_class * num_boxes, 3)], [selected_indices, num_detections, row_offsets], lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]), dtype=["int64"], - in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], name="collect_indices", tag="collect_indices", ) +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir +): + """Collect selected indices and scores from the core NMS loop into one linear output + + Parameters + ---------- + num_class : int + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores + of selected boxes by the core NMS loop. + + num_detections tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), representing + the number of boxes selected by the core NMS loop, per batch and class + + row_offsets tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan + of num_detections along axis 1 + + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors. The first is indices of size + (batch_size, num_class* num_boxes, 2), and the second is scores of size + (batch_size, num_class* num_boxes). + """ + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), + dtype=["int64", "float32"], + name="collect_indices_and_scores", + tag="collect_indices_and_scores", + ) + + def _all_class_nms_ir( boxes, sorted_scores, @@ -139,6 +174,7 @@ def _all_class_nms_ir( iou_threshold, max_output_size_per_class, box_indices, + selected_scores, num_valid_boxes, nms_loop, ): @@ -150,6 +186,9 @@ def _all_class_nms_ir( box_indices = ib.buffer_ptr(box_indices) num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -171,6 +210,9 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + def on_new_invalidated_box(*_): pass @@ -201,6 +243,7 @@ def run_all_class_nms( max_output_size_per_class, iou_threshold, nms_loop, + return_scores=False, ): """The core all class NMS routine @@ -230,31 +273,49 @@ def run_all_class_nms( nms_loop : function A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + return_scores : bool, optional + Whether or not to return selected scores, needed by the tensorflow output format. + Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is indices of size - (batch_size * num_class, num_boxes) and the second is a tensor + out : a list of tvm.te.Tensor + The output is three tensors, the first and second are indices and scores of size + (batch_size * num_class, num_boxes), and the third is a tensor num_selected_boxes of shape (batch_size * num_class,) representing the total number of - selected boxes per batch and class. + selected boxes per batch and class. If return_scores is False, the second output is + None. """ batch, num_boxes, _ = boxes.shape batch_class = sorted_scores.shape[0] num_class = batch_class // batch - boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) - sorted_scores_buf = tvm.tir.decl_buffer( - sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 - ) - sorted_indices_buf = tvm.tir.decl_buffer( - sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, "int32", "valid_count_buf", data_alignment=4 - ) + if return_scores is False: + selected_indices, num_detections = te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ), + dtype=["int32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) + return selected_indices, None, num_detections return te.extern( - [(batch_class, num_boxes), (1, batch_class)], + [(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( ins[0], # boxes @@ -267,16 +328,11 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - outs[1], # num_selected_boxes + outs[1], # selected scores + outs[2], # num_selected_boxes nms_loop, ), - dtype=["int32", "int32"], - in_buffers=[ - boxes_buf, - sorted_scores_buf, - sorted_indices_buf, - valid_count_buf, - ], + dtype=["int32", "float32", "int32"], name="all_class_nms", tag="all_class_nms", ) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index df480123375d..37bdd09d6ca6 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -139,7 +139,7 @@ def _default_batch_matmul_config(cfg, M, N, K): def batch_matmul_blas_common(cfg, x, y, out_shape, lib): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. + data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension. Parameters ---------- @@ -162,10 +162,10 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" XB, M, XK = get_const_tuple(x.shape) YB, N, YK = get_const_tuple(y.shape) - assert XB == YB, "batch dimension doesn't match" + assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistent" if out_shape is not None: - assert out_shape[0] == XB, "got invalid output shape" + assert out_shape[0] in (XB, YB), "got invalid output shape" assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) diff --git a/python/tvm/topi/x86/bitserial_conv2d.py b/python/tvm/topi/x86/bitserial_conv2d.py index 18f305094754..73c9dd56517f 100644 --- a/python/tvm/topi/x86/bitserial_conv2d.py +++ b/python/tvm/topi/x86/bitserial_conv2d.py @@ -39,7 +39,7 @@ def bitserial_conv2d_nchw( out_dtype="int16", unipolar=True, ): - """ Compute convolution with pack on spatial axes. """ + """Compute convolution with pack on spatial axes.""" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) # Check if kernel is already bitpacked @@ -181,7 +181,7 @@ def bitserial_conv2d_nhwc( out_dtype="int16", unipolar=True, ): - """ Compute convolution with pack on spatial axes. """ + """Compute convolution with pack on spatial axes.""" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) pack_kernel = len(kernel.shape) == 4 diff --git a/python/tvm/topi/x86/conv3d.py b/python/tvm/topi/x86/conv3d.py index cb202f5257af..d5b09e640e16 100644 --- a/python/tvm/topi/x86/conv3d.py +++ b/python/tvm/topi/x86/conv3d.py @@ -471,7 +471,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout): def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"): - """ Get the workload structure. """ + """Get the workload structure.""" if data_layout == "NCDHW": _, CI, ID, IH, IW = get_const_tuple(data.shape) CO, CIG, KD, KH, KW = get_const_tuple(kernel.shape) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 6011f01c2cb0..4fed4c16464e 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -155,6 +155,7 @@ def _default_dense_nopack_config(cfg, M, N, K): cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) cfg["tile_x"] = SplitEntity([N, 1]) cfg["tile_y"] = SplitEntity([1, M]) + return M, N, K @autotvm.register_topi_compute("dense_nopack.x86") @@ -175,7 +176,7 @@ def dense_nopack(cfg, data, weight, bias=None, out_dtype=None): "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 ) if cfg.is_fallback: - _default_dense_nopack_config(cfg, M, N, K) + M, N, K = _default_dense_nopack_config(cfg, M, N, K) vec = cfg["tile_k"].size[-1] k = te.reduce_axis((0, K // vec), "k") diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py index 0994700fe98c..4c39f2ad7382 100644 --- a/python/tvm/topi/x86/nn.py +++ b/python/tvm/topi/x86/nn.py @@ -42,9 +42,17 @@ def schedule_softmax(outs): exp = softmax.op.input_tensors[0] expsum = softmax.op.input_tensors[1] max_elem = s[exp].op.input_tensors[1] + delta = None + axis = int(softmax.op.attrs["axis"]) + elif op_tag == "fast_softmax_output": + exp = softmax.op.input_tensors[0] + expsum = softmax.op.input_tensors[1] + delta = s[exp].op.input_tensors[0] + max_elem = s[delta].op.input_tensors[1] axis = int(softmax.op.attrs["axis"]) elif op_tag == "log_softmax_output": exp = None + delta = None max_elem = softmax.op.input_tensors[1] expsum = softmax.op.input_tensors[2] axis = 1 @@ -65,6 +73,9 @@ def schedule_softmax(outs): s[max_elem].compute_at(s[softmax], fused_outer_axes) s[expsum].compute_at(s[softmax], fused_outer_axes) + if delta is not None: + s[exp].compute_inline() + s[delta].compute_inline() if exp is not None: s[exp].compute_at(s[softmax], fused_outer_axes) diff --git a/src/README.md b/src/README.md index 2653efa56c83..bb9aeb2a8578 100644 --- a/src/README.md +++ b/src/README.md @@ -21,14 +21,17 @@ Header files in include are public APIs that share across modules. There can be internal header files within each module that sit in src. ## Modules -- support: Internal support utilities. -- runtime: Minimum runtime related codes. -- node: base infra for IR/AST nodes that is dialect independent. -- ir: Common IR infrastructure. -- tir: Tensor-level IR. -- te: tensor expression DSL - arith: Arithmetic expression and set simplification. -- relay: Relay IR, high-level optimization. -- autotvm: The auto-tuning module. +- auto\_scheduler: The template-free auto-tuning module. +- autotvm: The template-based auto-tuning module. - contrib: Contrib extension libraries. - driver: Compilation driver APIs. +- ir: Common IR infrastructure. +- node: The base infra for IR/AST nodes that is dialect independent. +- relay: Relay IR, high-level optimizations. +- runtime: Minimum runtime related codes. +- support: Internal support utilities. +- target: Hardwaer target. +- tir: Tensor IR, low-level optimizations. +- te: Tensor expression DSL. +- topi: Tensor Operator Inventory. diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index edcb6f8a2c92..c1daae967b47 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -242,9 +242,9 @@ class IterMapRewriter : public ExprMutator { * either 1) follow inclusion relation or 2) have no intersection * * For Example, x = i0*30 + i1*15 + i2*3 + i3, - * 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \intersect {i2, i3} = empty set. + * 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \\intersect {i2, i3} = empty set. * 2) [i0*2 + i1 < 3, i1*5 + i2 < 5] is not valid, - * since {i0, i1} \intersect {i1, i2} = {i1}, i0 \in {i0, i1}, i0 \notin {i1, i2} + * since {i0, i1} \\intersect {i1, i2} = {i1}, i0 \\in {i0, i1}, i0 \\notin {i1, i2} * \return whether the predicates are valid; */ bool CheckConstraints() const { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f30cecbf7f05..cd8173717d5f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -93,22 +92,62 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } -void GetBinds(const Array& args, bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { *out_binds = binds; - for (const auto& x : args) { - if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact); - out_binds->Set(x, buf); - out_arg_list->push_back(buf); + for (const ObjectRef& x : args) { + if (const te::TensorNode* tensor_node = x.as()) { + te::Tensor x_ref = GetRef(tensor_node); + if (out_binds->find(x_ref) == out_binds->end()) { + tir::Buffer buf = + BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); + out_binds->Set(x_ref, buf); + out_arg_list->push_back(buf); + } else { + out_arg_list->push_back((*out_binds)[x_ref]); + } + } else if (x.as() || x.as()) { + out_arg_list->push_back(x); } else { - out_arg_list->push_back((*out_binds)[x]); + LOG(FATAL) + << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " + << "but got a " << x->GetTypeKey(); } } } +void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + GetBinds(ref_args, compact, binds, out_binds, out_arg_list); +} + +TVM_REGISTER_GLOBAL("driver.get_binds") + .set_body_typed([](const Array& args, bool compact, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); + + // TVM object system doesn't have a pair object, so we'll put both ret values in an array + // and return that. + Array out_arr = {out_binds, out_arg_list}; + return out_arr; + }); + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -128,63 +167,208 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { - Array out_arg_list; - auto pass_ctx = transform::PassContext::Current(); - - sch = sch.normalize(); - - // Before TIR transformation. - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); +Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { + transform::PassContext pass_ctx = transform::PassContext::Current(); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + // Get any user-added passes + Array> add_lower_pass = + pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) + .value(); + + Array user_lower_phase0 = Array(); + Array user_lower_phase1 = Array(); + Array user_lower_phase2 = Array(); + Array user_lower_phase3 = Array(); + + // phase pasees is of the form + // [[phase_number, pass], [phase_number, pass]... ] + for (Array phase_pass : add_lower_pass) { + const IntImmNode* phase_num = phase_pass[0].as(); + ICHECK(phase_num) + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + int phase_num_val = phase_num->value; + + CHECK_GE(phase_num_val, 0); + + const tvm::transform::PassNode* pass_node = phase_pass[1].as(); + tvm::transform::Pass pass = GetRef(pass_node); + // Copy the pass into the correct phase + if (phase_num_val == 0) { + user_lower_phase0.push_back(pass); + } else if (phase_num_val == 1) { + user_lower_phase1.push_back(pass); + } else if (phase_num_val == 2) { + user_lower_phase2.push_back(pass); + } else if (phase_num_val >= 3) { + user_lower_phase3.push_back(pass); + } } - auto mod = IRModule(Map({{GlobalVar(name), f}})); - auto pass_list = Array(); + // Construct the pass list, inserting the user provided passes at the end of the phase - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 + // PHASE 0 + Array pass_list = user_lower_phase0; + + // PHASE 1 + if (for_te_schedule) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + } else { + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::FlattenBuffer()); + } pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LoopPartition()); + + // Add user-defined phase-1 passes + pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); + + // PHASE 2 + if (!disable_loop_partition) { + pass_list.push_back(tir::transform::LoopPartition()); + } + pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::UnrollLoop()); - // Phase 2 + + // Add user-defined phase-2 passes + pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); + + // PHASE 3 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + pass_list.push_back(tir::transform::HoistIfThenElse()); + + // Add user-defined phase-3 passes + pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); + if (instrument_bound_checkers) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } - // run - auto optimize = transform::Sequential(pass_list); + return pass_list; +} + +IRModule LowerWithPassList(IRModule mod, Array pass_list) { + auto optimize = tvm::transform::Sequential(pass_list); mod = optimize(std::move(mod)); return mod; } +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { + // Convert te schedule to IRModule + Array out_arg_list; + transform::PassContext pass_ctx = transform::PassContext::Current(); + + sch = sch.normalize(); + + // Before TIR transformation. + Map bounds = te::InferBound(sch); + tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // Build the function + // At this point binds is only te::Tensors + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + return IRModule(Map({{GlobalVar(name), f}})); +} + +TVM_REGISTER_GLOBAL("driver.schedule_to_module") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds); + return mod; + }); + +IRModule LowerModule(IRModule mod, bool simple_mode) { + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(std::move(mod), simple_mode); +}); + +IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + + // Get the pass list + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_primfunc") + .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { + return LowerPrimFunc(std::move(func), name, simple_mode); + }); + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + return LowerSchedule(std::move(sch), ref_args, name, binds); +} + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); + // Get the legacy TE pass list + Array pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(mod, pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_schedule") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); + }); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { @@ -200,8 +384,15 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - mixed_pass_list.push_back(tir::transform::MakePackedAPI(0)); + + if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); + } else { + mixed_pass_list.push_back(tir::transform::MakePackedAPI(0)); + } + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + auto opt_mixed = transform::Sequential(mixed_pass_list); mod_mixed = opt_mixed(std::move(mod_mixed)); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 203520802091..caddf0efcc77 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -157,39 +157,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "GlobalVar(" << node->name_hint << ")"; }); -// Container printer -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0; i < op->size(); ++i) { - if (i != 0) { - p->stream << ", "; - } - p->Print(op->at(i)); - } - p->stream << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->begin(); it != op->end(); ++it) { - if (it != op->begin()) { - p->stream << ", "; - } - if (it->first->IsInstance()) { - p->stream << '\"' << Downcast(it->first) << "\": "; - } else { - p->Print(it->first); - p->stream << ": "; - } - p->Print(it->second); - } - p->stream << '}'; - }); - TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc new file mode 100644 index 000000000000..795e5b8cb542 --- /dev/null +++ b/src/ir/instrument.cc @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/instrument.cc + * \brief Infrastructure for instrumentation. + */ +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace instrument { + +/*! + * \brief Base PassInstrument implementation + * \sa BasePassInstrument + */ +class BasePassInstrumentNode : public PassInstrumentNode { + public: + /*! \brief Callback to run when entering PassContext. */ + runtime::TypedPackedFunc enter_pass_ctx_callback; + /*! \brief Callback to run when exiting PassContext. */ + runtime::TypedPackedFunc exit_pass_ctx_callback; + + /*! \brief Callback determines whether to run a pass or not. */ + runtime::TypedPackedFunc should_run_callback; + + /*! \brief Callback to run before a pass. */ + runtime::TypedPackedFunc + run_before_pass_callback; + /*! \brief Callback to run after a pass. */ + runtime::TypedPackedFunc + run_after_pass_callback; + + /*! \brief Instrument when entering PassContext. */ + void EnterPassContext() const final; + + /*! \brief Instrument when exiting PassContext. */ + void ExitPassContext() const final; + + /*! + * \brief Determine whether to run the pass or not. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + bool ShouldRun(const IRModule&, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument before pass run. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument after pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; + + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode); +}; + +/*! + * \brief Managed reference class for BasePassInstrumentNode + * \sa BasePassInstrumentNode + */ +class BasePassInstrument : public PassInstrument { + public: + /*! + * \brief Constructor + * + * \param name Name for this instrumentation. + * + * + * \param enter_pass_ctx_callback Callback to call when entering pass context. + * \param exit_pass_ctx_callback Callback to call when exiting pass context. + * + * \param should_run_callback Callback to determine whether pass should run. (return true: enable; + * return false: disable) + * + * \param run_before_pass_callback Callback to call before a pass run. + * \param run_after_pass_callback Callback to call after a pass run. + */ + TVM_DLL BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, + runtime::TypedPackedFunc + should_run_callback, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback); + + TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); +}; + +BasePassInstrument::BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, + runtime::TypedPackedFunc should_run_callback, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback) { + auto pi = make_object(); + pi->name = std::move(name); + + pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); + pi->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback); + + pi->should_run_callback = std::move(should_run_callback); + + pi->run_before_pass_callback = std::move(run_before_pass_callback); + pi->run_after_pass_callback = std::move(run_after_pass_callback); + + data_ = std::move(pi); +} + +void BasePassInstrumentNode::EnterPassContext() const { + if (enter_pass_ctx_callback != nullptr) { + enter_pass_ctx_callback(); + } +} + +void BasePassInstrumentNode::ExitPassContext() const { + if (exit_pass_ctx_callback != nullptr) { + exit_pass_ctx_callback(); + } +} + +bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (should_run_callback == nullptr) { + return true; + } + + return should_run_callback(ir_module, pass_info); +} + +void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_before_pass_callback != nullptr) { + run_before_pass_callback(ir_module, pass_info); + } +} + +void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_after_pass_callback != nullptr) { + run_after_pass_callback(ir_module, pass_info); + } +} + +TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); + +TVM_REGISTER_GLOBAL("instrument.PassInstrument") + .set_body_typed( + [](String name, runtime::TypedPackedFunc enter_pass_ctx, + runtime::TypedPackedFunc exit_pass_ctx, + runtime::TypedPackedFunc should_run, + runtime::TypedPackedFunc + run_before_pass, + runtime::TypedPackedFunc + run_after_pass) { + return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, + run_before_pass, run_after_pass); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->name; + }); + +/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ +struct PassProfile { + // TODO(@altanh): expose PassProfile through TVM Object API + using Clock = std::chrono::steady_clock; + using Duration = std::chrono::duration; + using Time = std::chrono::time_point; + + /*! \brief The name of the pass being profiled. */ + String name; + /*! \brief The time when the pass was entered. */ + Time start; + /*! \brief The time when the pass completed. */ + Time end; + /*! \brief The total duration of the pass, i.e. end - start. */ + Duration duration; + /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ + std::vector children; + + explicit PassProfile(String name) + : name(name), start(Clock::now()), end(Clock::now()), children() {} + + /*! \brief Gets the PassProfile of the currently executing pass. */ + static PassProfile* Current(); + /*! \brief Pushes a new PassProfile with the given pass name. */ + static void EnterPass(String name); + /*! \brief Pops the current PassProfile. */ + static void ExitPass(); +}; + +struct PassProfileThreadLocalEntry { + /*! \brief The placeholder top-level PassProfile. */ + PassProfile root; + /*! \brief The stack of PassProfiles for nested passes currently running. */ + std::stack profile_stack; + + PassProfileThreadLocalEntry() : root("root") {} +}; + +/*! \brief Thread local store to hold the pass profiling data. */ +typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; + +void PassProfile::EnterPass(String name) { + PassProfile* cur = PassProfile::Current(); + cur->children.emplace_back(name); + PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +} + +void PassProfile::ExitPass() { + PassProfile* cur = PassProfile::Current(); + ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; + cur->end = PassProfile::Clock::now(); + cur->duration = std::chrono::duration_cast(cur->end - cur->start); + PassProfileThreadLocalStore::Get()->profile_stack.pop(); +} + +PassProfile* PassProfile::Current() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + if (!entry->profile_stack.empty()) { + return entry->profile_stack.top(); + } else { + return &entry->root; + } +} + +String RenderPassProfiles() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; + + if (entry->root.children.empty()) { + LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; + return String(); + } + + // (depth, parent_duration, pass) + std::stack> profiles; + + // push top level passes + PassProfile::Duration top_dur(0); + for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { + top_dur += it->duration; + } + for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { + profiles.push(std::make_tuple(0, top_dur, &*it)); + } + + std::ostringstream os; + os << std::fixed; + + while (profiles.size() > 0) { + size_t depth; + PassProfile::Duration parent_duration; + PassProfile* profile; + std::tie(depth, parent_duration, profile) = profiles.top(); + profiles.pop(); + + // indent depth + for (size_t i = 0; i < depth; ++i) { + os << "\t"; + } + + // calculate time spent in pass itself (excluding sub-passes), and push children + PassProfile::Duration self_duration = profile->duration; + for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { + self_duration -= it->duration; + profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); + } + + double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; + double total_pct = profile->duration.count() / top_dur.count() * 100.0; + + os << profile->name << ": "; + os << std::setprecision(0); + os << profile->duration.count() << "us [" << self_duration.count() << "us] "; + os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; + } + + return os.str(); +} + +TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); + +TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { + auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::EnterPass(pass_info->name); + return true; + }; + + auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::ExitPass(); + }; + + auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; + + return BasePassInstrument("PassTimingInstrument", + /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr, + run_before_pass, run_after_pass); +}); + +} // namespace instrument +} // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index 8fd34d30ffa7..fac15a7daad4 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include #include @@ -102,10 +101,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) reg.reset_attr(attr_name); }); -TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) { +TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; - OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + op.describe(descr); +}); + +// This is exposed FFI api for prototyping using in python. +// Note: it is not full of the C++ type relation, +// since in python side we don't have access to the type reporter, +// and cannot propagate constraints to the inputs, only to the output. +TVM_REGISTER_GLOBAL("ir.OpAddTypeRel") + .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + if (value.type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc to avoid deleting function from frontend. + PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc()); + auto f = [=](const Array& args, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) -> bool { + Array input_types(args.begin(), args.end() - 1); + // call customized relation functions + // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type + Type ret_type = (*fcopy)(input_types, attrs); + // when defined ret_type, inference of output type is ok, do type assign + // otherwise, inference failure happens + if (ret_type.defined()) { + // the last argument is output + // TODO(xqdan): support multiple outputs + reporter->Assign(args.back(), ret_type); + return true; + } + return false; + }; + // adjust function call to call conventions of relay type system with TypeReporter + auto type_rel = runtime::TypedPackedFunc&, int, const Attrs&, + const TypeReporter&)>(f); + reg.add_type_rel(rel_name, type_rel); + } else if (value.type_code() == kTVMNullptr) { + // Call relation functions of relay + auto func_name = std::string("tvm.relay.type_relation.") + rel_name; + auto* f = runtime::Registry::Get(func_name); + ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered."; + reg.add_type_rel(rel_name, *f); + } + }); + +TVM_REGISTER_GLOBAL("ir.OpAddArgument") + .set_body_typed([](Op op, String name, String type, String description) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.add_argument(name, type, description); + }); + +TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_support_level(level); +}); + +TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_num_inputs(n); +}); + +TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_attrs_type_key(key); }); TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 48f13bc81df4..8120ca798ab2 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include @@ -56,6 +55,8 @@ struct PassContextThreadLocalEntry { typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { + InstrumentEnterPassContext(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } @@ -65,6 +66,8 @@ void PassContext::ExitWithScope() { ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); + + InstrumentExitPassContext(); } PassContext PassContext::Current() { @@ -142,6 +145,16 @@ class PassConfigManager { } } + Map> ListConfigs() { + Map> configs; + for (const auto& kv : key2vtype_) { + Map metadata; + metadata.Set("type", kv.second.type_key); + configs.Set(kv.first, metadata); + } + return configs; + } + static PassConfigManager* Global() { static auto* inst = new PassConfigManager(); return inst; @@ -160,172 +173,102 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassConfigManager::Global()->Register(key, value_type_index); } +Map> PassContext::ListConfigs() { + return PassConfigManager::Global()->ListConfigs(); +} + PassContext PassContext::Create() { return PassContext(make_object()); } -void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { +void PassContext::InstrumentEnterPassContext() { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); + if (pass_ctx_node->instruments.defined()) { + Array enter_successes; + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->EnterPassContext(); + enter_successes.push_back(pi); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation entering pass context failed."; + LOG(INFO) << "Disable pass instrumentation."; + pass_ctx_node->instruments.clear(); + + for (instrument::PassInstrument pi : enter_successes) { + LOG(INFO) << pi->name << " exiting PassContext ..."; + pi->ExitPassContext(); + LOG(INFO) << pi->name << " exited PassContext."; + } + enter_successes.clear(); + + throw e; + } } } -class ModulePass; - -/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ -struct PassProfile { - // TODO(@altanh): expose PassProfile through TVM Object API - using Clock = std::chrono::steady_clock; - using Duration = std::chrono::duration; - using Time = std::chrono::time_point; - - /*! \brief The name of the pass being profiled. */ - String name; - /*! \brief The time when the pass was entered. */ - Time start; - /*! \brief The time when the pass completed. */ - Time end; - /*! \brief The total duration of the pass, i.e. end - start. */ - Duration duration; - /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ - std::vector children; - - explicit PassProfile(String name) - : name(name), start(Clock::now()), end(Clock::now()), children() {} - - /*! \brief Gets the PassProfile of the currently executing pass. */ - static PassProfile* Current(); - /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); - /*! \brief Pops the current PassProfile. */ - static void ExitPass(); -}; - -struct PassProfileThreadLocalEntry { - /*! \brief The placeholder top-level PassProfile. */ - PassProfile root; - /*! \brief The stack of PassProfiles for nested passes currently running. */ - std::stack profile_stack; - /*! \brief Whether or not pass profiling is active. */ - bool active; - - PassProfileThreadLocalEntry() : root("root"), active(false) {} -}; +void PassContext::InstrumentExitPassContext() { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->ExitPassContext(); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation exiting pass context failed."; + pass_ctx_node->instruments.clear(); + throw e; + } + } +} -/*! \brief Thread local store to hold the pass profiling data. */ -typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; +bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (!pass_ctx_node->instruments.defined()) { + return true; + } -void PassProfile::EnterPass(String name) { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - cur->children.emplace_back(name); - PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); -} + const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); + bool should_run = true; + if (!pass_required) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + should_run &= pi->ShouldRun(ir_module, pass_info); + } + } -void PassProfile::ExitPass() { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; - cur->end = std::move(PassProfile::Clock::now()); - cur->duration = std::chrono::duration_cast(cur->end - cur->start); - PassProfileThreadLocalStore::Get()->profile_stack.pop(); + if (should_run) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunBeforePass(ir_module, pass_info); + } + } + return should_run; } -PassProfile* PassProfile::Current() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - if (!entry->profile_stack.empty()) { - return entry->profile_stack.top(); - } else { - return &entry->root; +void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunAfterPass(ir_module, pass_info); + } } } IRModule Pass::operator()(IRModule mod) const { - const PassNode* node = operator->(); - ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); - auto ret = node->operator()(std::move(mod)); - PassProfile::ExitPass(); - return std::move(ret); + return this->operator()(std::move(mod), PassContext::Current()); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + const PassInfo& pass_info = node->Info(); + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + DLOG(INFO) << "Skipping pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; + return mod; + } auto ret = node->operator()(std::move(mod), pass_ctx); - PassProfile::ExitPass(); + pass_ctx.InstrumentAfterPass(ret, pass_info); return std::move(ret); } -String RenderPassProfiles() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; - - if (entry->root.children.empty()) { - LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; - return String(); - } - - // (depth, parent_duration, pass) - std::stack> profiles; - - // push top level passes - PassProfile::Duration top_dur(0); - for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { - top_dur += it->duration; - } - for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { - profiles.push(std::make_tuple(0, top_dur, &*it)); - } - - std::ostringstream os; - os << std::fixed; - - while (profiles.size() > 0) { - size_t depth; - PassProfile::Duration parent_duration; - PassProfile* profile; - std::tie(depth, parent_duration, profile) = profiles.top(); - profiles.pop(); - - // indent depth - for (size_t i = 0; i < depth; ++i) { - os << "\t"; - } - - // calculate time spent in pass itself (excluding sub-passes), and push children - PassProfile::Duration self_duration = profile->duration; - for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { - self_duration -= it->duration; - profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); - } - - double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; - double total_pct = profile->duration.count() / top_dur.count() * 100.0; - - os << profile->name << ": "; - os << std::setprecision(0); - os << profile->duration.count() << "us [" << self_duration.count() << "us] "; - os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; - } - - return os.str(); -} - -TVM_REGISTER_GLOBAL("transform.render_pass_profiles").set_body_typed(RenderPassProfiles); - -TVM_REGISTER_GLOBAL("transform.clear_pass_profiles").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->root.children.clear(); -}); - -TVM_REGISTER_GLOBAL("transform.enable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = true; -}); - -TVM_REGISTER_GLOBAL("transform.disable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = false; -}); - /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes @@ -464,12 +407,11 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); + ICHECK(mod.defined()) << "The input module must be set."; + DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - ICHECK(mod.defined()) << "The input module must be set."; - - pass_ctx.Trace(mod, pass_info, true); mod = pass_func(std::move(mod), pass_ctx); ICHECK(mod.defined()) << "The return value of a module pass must be set."; @@ -480,7 +422,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(mod, pass_info, false); return mod; } @@ -621,13 +562,14 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, - TraceFunc trace_func, Optional> config) { + Array instruments, + Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); + pctx->instruments = std::move(instruments); if (config.defined()) { pctx->config = config.value(); } @@ -642,17 +584,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\trequired passes: ["; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; + p->stream << "\trequired passes: " << node->required_pass << "\n"; + p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; + p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tdisabled passes: ["; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; p->stream << "\tconfig: " << node->config; }); @@ -669,6 +604,13 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.OverrideInstruments") + .set_body_typed([](PassContext pass_ctx, Array instruments) { + pass_ctx.InstrumentExitPassContext(); + pass_ctx->instruments = instruments; + pass_ctx.InstrumentEnterPassContext(); + }); + Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); @@ -679,5 +621,7 @@ Pass PrintIR(String header, bool show_meta_data) { TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); +TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); + } // namespace transform } // namespace tvm diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index f84be1467453..050f9e5b2845 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -25,7 +25,6 @@ #define TVM_NODE_ATTR_REGISTRY_H_ #include -#include #include #include diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc new file mode 100644 index 000000000000..1565630cc6ac --- /dev/null +++ b/src/node/container_printing.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Printer implementation for containers + * \file node/container_printint.cc + */ +#include +#include +#include + +namespace tvm { + +// Container printer +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->at(i)); + } + p->stream << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->begin(); it != op->end(); ++it) { + if (it != op->begin()) { + p->stream << ", "; + } + if (it->first->IsInstance()) { + p->stream << '\"' << Downcast(it->first) << "\": "; + } else { + p->Print(it->first); + p->stream << ": "; + } + p->Print(it->second); + } + p->stream << '}'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->size; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << op->data[i]; + } + p->stream << ']'; + }); +} // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 79a53aa26440..a7c3493e7feb 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include namespace tvm { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 75f03fbc7954..94dfda556cc9 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 05327b1ca303..f5344ab9126e 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include diff --git a/src/parser/op_table.h b/src/parser/op_table.h index 050904f23280..28c9cd7fc05f 100644 --- a/src/parser/op_table.h +++ b/src/parser/op_table.h @@ -28,7 +28,6 @@ #define TVM_PARSER_OP_TABLE_H_ #include -#include #include #include diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 793d6bb9a43d..c6407e8909d9 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -524,13 +524,22 @@ class Parser { NDArray NumberToNDArray(const Token& token) { if (token->token_type == TokenType::kInteger) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto dtype = String2DLDataType("int32"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - int64_t value = Downcast(token->data); - array[0] = (int32_t)value; - return data; + int64_t i = Downcast(token->data); + if (i > std::numeric_limits::max()) { + auto dtype = String2DLDataType("int64"); + auto data = NDArray::Empty({}, dtype, dev); + auto array = reinterpret_cast(data->data); + // revisit this, literal node issue. + array[0] = i; + return data; + } else { + auto dtype = String2DLDataType("int32"); + auto data = NDArray::Empty({}, dtype, dev); + auto array = reinterpret_cast(data->data); + // revisit this, literal node issue. + array[0] = i; + return data; + } } else if (token->token_type == TokenType::kFloat) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto float_imm = Downcast(token->data); @@ -1516,7 +1525,7 @@ class Parser { } case TokenType::kBoolean: { Consume(TokenType::kBoolean); - int value = Downcast(next->data); + int64_t value = Downcast(next->data); auto boolean = BooleanToNDarray(value); Expr e = Constant(boolean, next->span); ICHECK(e->span.defined()) << "constant spans must be defined"; diff --git a/src/parser/span_check.h b/src/parser/span_check.h index ab71d30a54f5..0074c66d61f4 100644 --- a/src/parser/span_check.h +++ b/src/parser/span_check.h @@ -29,7 +29,6 @@ #include #include #include -#include #include #include diff --git a/src/parser/token.h b/src/parser/token.h index 1133483fa8f8..31e974355e4b 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -26,7 +26,6 @@ #define TVM_PARSER_TOKEN_H_ #include -#include #include #include diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 5e71794cc7fb..e26c97429d6b 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -25,10 +25,10 @@ #define TVM_PARSER_TOKENIZER_H_ #include -#include #include #include +#include #include #include #include @@ -172,44 +172,53 @@ struct Tokenizer { Token ParseNumber(bool is_pos, bool is_float, std::string number) { ICHECK(number.size() > 0) << "an empty string is an invalid number"; - try { - if (is_float) { - throw std::invalid_argument("is_float"); - } + if (!is_float) { auto token = NewToken(TokenType::kInteger); size_t index = 0; - int value = std::stoi(number, &index); - if (number.size() > index) { - throw std::invalid_argument("floating point"); + int64_t value = 0; + try { + value = std::stoll(number, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); } - value = is_pos ? value : -value; - token->data = tvm::Integer(value); - return token; - } catch (const std::invalid_argument& ia) { - auto token = NewToken(TokenType::kFloat); + if (number.size() <= index) { + value = is_pos ? value : -value; + if (value > std::numeric_limits::max()) { + token->data = tvm::IntImm(DataType::Int(64), value); + } else { + token->data = tvm::IntImm(DataType::Int(32), value); + } + return token; + } + } + auto token = NewToken(TokenType::kFloat); - auto suffix_pos = number.rfind("f"); + auto suffix_pos = number.rfind("f"); - auto literal_text = number.substr(0, suffix_pos); + auto literal_text = number.substr(0, suffix_pos); - auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); + auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); - int width = 32; + int width = 32; - if (suffix.size()) { - try { - width = std::stoi(suffix); - } catch (const std::invalid_argument& err) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) - << "invalid numeric suffix `" << suffix << "`"); - } + if (suffix.size()) { + try { + width = std::stoi(suffix); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid numeric suffix `" << suffix << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid numeric suffix `" << suffix << "`"); } - - double value = stod(literal_text); - value = is_pos ? value : -value; - token->data = tvm::FloatImm(DataType::Float(width), value); - return token; } + + double value = stod(literal_text); + value = is_pos ? value : -value; + token->data = tvm::FloatImm(DataType::Float(width), value); + return token; } Token ParseNumber(bool is_pos) { diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index f76c32d353cf..b2e245bd5b45 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -25,7 +25,6 @@ #define TVM_PRINTER_META_DATA_H_ #include -#include #include #include diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 85a9c51a2fa8..840878390018 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index d225cb8ae82a..2e4eec23f733 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -33,7 +33,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 8dd6819e0e8c..35813f67d094 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -59,7 +59,6 @@ #include #include #include -#include #include namespace tvm { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index a48a8a847c7c..66294d1dd076 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -137,10 +137,17 @@ class AOTExecutorCodegen : public ExprVisitor { // Pack the sid inside the TVMValue auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle()); auto sid_value = sids_table_[sid]; - tvm::PrimExpr set_tensor = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, sid_value}); - stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); + + if (!use_unpacked_api_) { + tvm::PrimExpr set_tensor = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrData, sid_value}); + stmts_.push_back( + tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); + } else { + stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0))); + } + sid_vars.push_back(sid_array); } return sid_vars; @@ -148,7 +155,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! * \brief Utility function to return a parameter associated with an expression - * \param expr Relay Expression assicated with the parameter + * \param expr Relay Expression associated with the parameter * \return Variable that represents the DLTensor associated with the parameters */ tir::Var PackParam(Expr expr) { @@ -161,16 +168,16 @@ class AOTExecutorCodegen : public ExprVisitor { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[expr])}); - tvm::PrimExpr set_param_array = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {param_array, 0, tir::builtin::kArrData, param_handle}); - lookup_call.push_back(tir::Evaluate(set_param_array)); - - tir::Stmt lookup_body = tir::SeqStmt(lookup_call); + if (!use_unpacked_api_) { + tvm::PrimExpr set_param_array = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {param_array, 0, tir::builtin::kArrData, param_handle}); + stmts_.push_back( + tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array))); + } else { + stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0))); + } - // Allocate the DLTensors on the stack - lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body); - stmts_.push_back(lookup_body); return param_array; } @@ -206,15 +213,20 @@ class AOTExecutorCodegen : public ExprVisitor { } auto ret_expr = Downcast(call); - // Pack the return(s) value. A call node can produce multiple outputs for (const auto& var : PackSid(ret_expr)) { args.push_back(var); } - // Use tvm_call_packed to execute the function - create_func_call_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args))); + // Use tvm_call_packed to execute the function unless we're calling directly + auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + if (use_unpacked_api_) { + calling_pattern = tvm::tir::builtin::call_extern(); + } + + create_func_call_stmts.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); + tir::Stmt body = tir::SeqStmt(create_func_call_stmts); stmts_.push_back(body); } @@ -226,16 +238,20 @@ class AOTExecutorCodegen : public ExprVisitor { * copy-on-write fashion. */ void CopyToOutput(te::Var out, te::Var in, size_t size) { - auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), - {in, 0, tir::builtin::kArrData}); - // Define intermediate DLTensor to load/store the data auto tmp0 = te::Var("tmp0", DataType::Handle()); auto tmp1 = te::Var("tmp1", DataType::Handle()); te::Var loop_idx("i", DataType::Int(32)); auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); - auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), - {out, 0, tir::builtin::kArrData}); + + PrimExpr retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {in, 0, tir::builtin::kArrData}); + PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {out, 0, tir::builtin::kArrData}); + if (use_unpacked_api_) { + retval_get = in; + tostore = out; + } // Copy the variable from the input to the output tir::Stmt copy = tir::For( @@ -535,6 +551,15 @@ class AOTExecutorCodegen : public ExprVisitor { TargetsMap targets_; /*! \brief target host */ Target target_host_; + /*! + * \brief unpacked api toggle + * When set to true the code generated will use unpacked calls to functions: + * func(void* arg0, void* arg1) + * Rather than packed calls: + * func(void* args) + * Defaults to using the packed calling convention + */ + Bool use_unpacked_api_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -564,21 +589,20 @@ class AOTExecutorCodegen : public ExprVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) - : mod_(mod), return_sid_() { - compile_engine_ = CompileEngine::Global(); - targets_ = targets; - target_host_ = target_host; - } + : mod_(mod), + targets_(targets), + target_host_(target_host), + use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))), + compile_engine_(CompileEngine::Global()) {} LoweredOutput Codegen(relay::Function func) { // Get the module, storage map and token sizes auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); storage_device_map_ = (*pf)(func); - int input_index = 0; for (auto input : func->params) { input_vars_.push_back(input); - main_signature_.push_back(tir::Var(MakeString("input_", input_index), DataType::Handle())); + main_signature_.push_back(tir::Var("input", DataType::Handle())); } // Define the storage allocator ids @@ -592,7 +616,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Find the return sid return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func); for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { - main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle())); + main_signature_.push_back(tir::Var("output", DataType::Handle())); } VisitExpr(func->body); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 5e3b66b3ae15..29f7d30833a0 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -163,7 +162,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } } - // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. if (!schedule.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); @@ -763,15 +762,9 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } // lower the function - if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); - } else { - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + std::unordered_map binds; + cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); - std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); - } value->cached_func = CachedFunc(cache_node); return value; } @@ -807,7 +800,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index b81fd14b99c2..32eecec25b06 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 192e09140375..4966f3f01c7d 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index c9a58282d13e..e96255e976e9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -393,7 +393,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; - code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't diff --git a/src/relay/backend/contrib/ethosn/capabilities.h b/src/relay/backend/contrib/ethosn/capabilities.h deleted file mode 100644 index cc14ca101da6..000000000000 --- a/src/relay/backend/contrib/ethosn/capabilities.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/backend/contrib/ethosn/capabilities.h - * \brief The Ethos-N processor series has four variants, the Ethos-N37, Ethos-N57, Ethos-N77 - * and the Ethos-N78. This release of the integration supports the first three variants and - * the default configuration of the fourth variant. - * Configuration information for each variant is stored as a blob in this file. These blobs - * are passed into the Ethos-N support library, which in turn uses them to optimize the - * generated command-stream appropriately for the specified variant. - */ - -#ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ -#define TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ - -#include - -#include "ethosn_api_version.h" - -namespace tvm { -namespace relay { -namespace contrib { -namespace ethosn { - -/* Ethos-N variants (Ethos-N77, Ethos-N57, Ethos-N37 and Ethos-N78) - * variant[0] - Ethos-N77 - * variant[1] - Ethos-N57 - * variant[2] - Ethos-N37 - * variant[3] - Ethos-N78 - */ -#if _ETHOSN_API_VERSION_ == 2011 -static std::vector variants[4] = { - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x02, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x02, - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - }}; -#else -static std::vector variants[4] = { - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x02, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - }}; -#endif -} // namespace ethosn -} // namespace contrib -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index dab0e6c42f80..97b308e51e18 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -24,7 +24,6 @@ #include #include -#include "capabilities.h" #include "codegen_ethosn.h" #include "ethosn_api.h" @@ -198,19 +197,14 @@ sl::TensorsAndId MakeOps(const sl::TensorAndId& op) { NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { // Initialise everything -#if _ETHOSN_API_VERSION_ >= 2011 auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.ethos-n.options"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } -#endif NetworkWithIDs network_with_ids; -#if _ETHOSN_API_VERSION_ >= 2011 - network_ = sl::CreateNetwork(variants[cfg.value()->variant]); -#else - network_ = sl::CreateNetwork(); -#endif + network_ = sl::CreateNetwork(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); network_with_ids.network = network_; operand_table_.clear(); @@ -572,11 +566,7 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { cfg = AttrsWithDefaultValues(); } -#if _ETHOSN_API_VERSION_ >= 2011 sl::CompilationOptions options; -#else - sl::CompilationOptions options(variants[cfg.value()->variant]); -#endif options.m_Strategy0 = cfg.value()->strategy0; options.m_Strategy1 = cfg.value()->strategy1; options.m_Strategy3 = cfg.value()->strategy3; @@ -590,9 +580,6 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { options.m_BlockConfig8x32 = cfg.value()->block_config_8x32; options.m_BlockConfig8x8 = cfg.value()->block_config_8x8; options.m_EnableIntermediateCompression = cfg.value()->enable_intermediate_compression; -#if _ETHOSN_API_VERSION_ == 2008 - options.m_DebugInfo.m_DumpDebugFiles = cfg.value()->dump_debug_files; -#endif options.m_DisableWinograd = cfg.value()->disable_winograd; options.m_DebugInfo.m_DebugDir = cfg.value()->debug_dir; options.m_CompilerAlgorithm = @@ -619,20 +606,18 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -#if _ETHOSN_API_VERSION_ >= 2011 auto ctx = transform::PassContext::Current(); auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() ? ctx -> GetConfig("relay.ext.ethos-n.options") : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(variants[cfg.value()->variant]); -#endif +auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 if (params.is_depthwise) { *rv = !err && m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, @@ -641,15 +626,6 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, params.conv_info, params.activation_info); } -#else - if (params.is_depthwise) { - *rv = !err && sl::IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); - } else { - *rv = !err && sl::IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); - } -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") @@ -657,13 +633,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, params.fc_info, params.input_info); -#else - *rv = !err && sl::IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -671,11 +642,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); -#else - *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -683,11 +650,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); -#else - *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -695,11 +658,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); -#else - *rv = !err && sl::IsReshapeSupported(params.new_shape, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -707,13 +666,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, params.output_quantization_info); -#else - *rv = !err && sl::IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ -721,11 +675,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); -#else - *rv = !err && sl::IsSigmoidSupported(params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -733,11 +683,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); -#else - *rv = !err && sl::IsConcatenationSupported(params.input_infos, params.concat_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -745,11 +691,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); -#else - *rv = !err && sl::IsSplitSupported(params.input_info, params.split_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -757,11 +699,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); -#else - *rv = !err && sl::IsDepthToSpaceSupported(params.input_info, params.depth_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -769,11 +707,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); -#else - *rv = !err && sl::IsReluSupported(params.relu_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index e44aa31d6b13..63ae7a3e4704 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -226,7 +226,8 @@ NetworkWithIDs ConstructNetwork(const IRModule& mod, const GlobalVar& var, const /*! \brief Attributes to store the compiler options for Ethos-N */ struct EthosnCompilerConfigNode : public tvm::AttrsNode { - int variant; + String variant; + int sram_size_bytes; bool strategy0; bool strategy1; bool strategy3; @@ -240,18 +241,14 @@ struct EthosnCompilerConfigNode : public tvm::AttrsNode { Expr let_binding = GetRef(l); const LetNode* let; while ((let = let_binding.as())) { + ICHECK(!let->value.as()) + << "invariant violated, inner functions should not exist (did you set opt_level = 2?)"; VisitExpr(let->value); var_register_map_.insert({let->var, this->last_register_}); let_binding = let->body; @@ -490,6 +492,9 @@ class VMFunctionCompiler : ExprFunctor { argument_registers.push_back(reg->second); } + // Extract functions attrs + op_attrs[op_index] = func->attrs->dict; + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(), argument_registers)); } @@ -1156,25 +1161,24 @@ void VMCompiler::Codegen() { if (cached_funcs.size() == 0) { return; } - std::unordered_map funcs; + Map funcs; for (auto& cfunc : cached_funcs) { - std::string target_str = cfunc->target->str(); + Target target = cfunc->target; // NOTE: because module, is mutable, we need to make an // explicit copy of the IRModule. IRModule mod = cfunc->funcs; mod.CopyOnWrite(); - if (target_str == "ext_dev") { + if (target->kind->device_type == kDLExtDev) { // Collect metadata in functions that are handled by external codegen. ICHECK(mod->ContainGlobalVar(cfunc->func_name)); Function func = Downcast(mod->Lookup(cfunc->func_name)); backend::UpdateConstants(func, ¶ms_); - continue; - } else if (funcs.count(target_str) == 0) { - funcs.emplace(target_str, mod); + } else if (funcs.count(target) == 0) { + funcs.Set(target, mod); } else { - funcs[target_str]->Update(mod); + funcs[target]->Update(mod); } } @@ -1182,11 +1186,7 @@ void VMCompiler::Codegen() { auto ext_mods = compile_engine->LowerExternalFunctions(); runtime::Module lib; if (funcs.size() > 0) { - Map build_funcs; - for (const auto& i : funcs) { - build_funcs.Set(i.first, i.second); - } - lib = tvm::build(build_funcs, target_host_); + lib = tvm::build(funcs, target_host_); } else { // There is no function handled by TVM. We create a virtual main module // to make sure a DSO module will be also available. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 6ed24d5053c4..5ce06d9fefaa 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -131,6 +131,8 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { return rhs.operator std::string() == val->value; } else if (auto* val = lhs.as()) { return rhs.operator std::string() == val->data; + } else { + ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs; } break; case kTVMObjectHandle: @@ -140,6 +142,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } else if (auto* val = lhs.as()) { return rhs.operator String() == val->data; } + } else { + // Compare the objects for structural equality + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + return true; + } } break; default: diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 62ff0b1a86b3..3b3c8797d7f2 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -115,6 +115,8 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s n->attrs = std::move(attrs); n->type_args = std::move(type_args); n->span = std::move(span); + n->saved_deleter_ = n->deleter_; + n->deleter_ = CallNode::Deleter_; data_ = std::move(n); } @@ -288,16 +290,24 @@ inline void Dismantle(const Expr& expr) { // special handling if (const CallNode* op = node.as()) { - for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { - fpush_to_stack(*it); + // do not process args if used elsewhere + if (op->args.use_count() < 2) { + for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { + fpush_to_stack(*it); + } } - fpush_to_stack(op->op); } else if (const TupleNode* op = node.as()) { - for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { - fpush_to_stack(*it); + // do not process fields if used elsewhere + if (op->fields.use_count() < 2) { + for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { + fpush_to_stack(*it); + } } } else if (const TupleGetItemNode* op = node.as()) { - fpush_to_stack(op->tuple); + // do not process tuple if used elsewhere + if (op->tuple.use_count() < 2) { + fpush_to_stack(op->tuple); + } } } } @@ -306,7 +316,6 @@ inline void Dismantle(const Expr& expr) { /* * Non-recursive destructor */ - Call::~Call() { // attempt to dismantle if referenced one or zero times if (this->use_count() < 2) { @@ -316,5 +325,16 @@ Call::~Call() { } } +/* + * CallNode's deleter + */ +void CallNode::Deleter_(Object* ptr) { + auto p = reinterpret_cast(ptr); + // resore original deleter + p->deleter_ = p->saved_deleter_; + // create Call reference in order to invoke ~Call + auto c = GetRef(p); +} + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 596f812e25af..4a7974cae5ae 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -133,8 +133,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - pass_ctx.Trace(mod, pass_info, true); - // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); @@ -159,8 +157,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(updated_mod, pass_info, false); - // TODO(@jroesch): move away from eager type checking for performance reasons // make issue. return transform::InferType()(updated_mod); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index bbfef5883e3d..81de4bc90ad7 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -78,7 +78,8 @@ Expr MakeStack(Expr data, int axis); Expr MakeTranspose(Expr data, Array axes); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode); + String slice_mode, + Optional> axes = NullValue>()); Expr MakeTile(Expr data, Array reps); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 38cb763883b7..1ac800f357b0 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -27,7 +27,6 @@ #include #include #include -#include #include diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bf45a412050f..9361e1996796 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2445,99 +2445,40 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr return false; } - auto dshape = data->shape; - int64_t num_axis = dshape.size(); - - // calculate output shape - std::vector oshape(num_axis); - if (param->begin && param->end && param->strides) { - // stride will be set as 1 if slice mode is enabled - std::vector stride_vec(num_axis, 1); - if (param->slice_mode == "end") { - for (size_t i = 0; i < param->strides.value().size(); ++i) { - ICHECK(param->strides.value()[i].defined()); - stride_vec[i] = param->strides.value()[i]->value; - } - } - const int64_t max_range = std::numeric_limits::max(); - std::vector begin_vec; - for (size_t i = 0; i < param->begin.value().size(); ++i) { - if (!param->begin.value()[i].defined()) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(param->begin.value()[i]->value); - } - } - for (int64_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } + ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; + ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; + ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; + + auto begin = param->begin.value(); + auto end = param->end.value(); + auto strides = param->strides.value(); + + const size_t src_tensor_dim = static_cast(data->shape.size()); + Array axes; + if (param->axes) { + axes = param->axes.value(); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()) + << "axes, begin, end, and strides must have the same length"; + } else { + for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - std::vector end_vec; - for (size_t i = 0; i < param->end.value().size(); ++i) { - // allow end to be None - if (!param->end.value()[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else if (param->slice_mode == "size") { - if (param->end.value()[i]->value < 0) { - end_vec.push_back(max_range); - } else { - end_vec.push_back(begin_vec[i] + param->end.value()[i]->value); - } - } else if (param->slice_mode == "end") { - end_vec.push_back(param->end.value()[i]->value); - } else { - LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; - } + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides.push_back(one); } - for (int64_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range); } - - for (int64_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; - - if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || - (stride_v == -1 && begin_v == max_range && end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; - } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - if (!p_dim_size) { - oshape[i] = dshape[i]; - continue; - } - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - ICHECK_GE(stride_v, 0); - ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; - } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range); } - } else { - ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; - ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; - ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; } + auto oshape = + topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode); reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -2596,78 +2537,130 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, // Not support NHW4c -> NCHW return {{Layout::Undef()}, {Layout::Undef()}}; } else { - for (size_t i = 0; i < new_layout_name.size(); ++i) { - auto index = layout.IndexOf(new_layout[i]); - if (index == -1) { - return {{Layout::Undef()}, {Layout::Undef()}}; + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + new_strides.push_back(strides[i]); + new_axes.push_back(new_idx); } + params->axes = new_axes; - size_t new_index = static_cast(index); - int64_t bg, ed, st; - if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { - st = strides[new_index]->value; - } else { - st = 1; - } - if (new_index < begin.size() && begin[new_index].defined()) { - bg = begin[new_index]->value; - } else { - bg = 0; - } - if (new_index < end.size() && end[new_index].defined()) { - ed = end[new_index]->value; - } else { - ed = shape[new_index].as()->value; - } + } else { + for (size_t i = 0; i < new_layout_name.size(); ++i) { + auto index = layout.IndexOf(new_layout[i]); + if (index == -1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } + + size_t new_index = static_cast(index); + int64_t bg, ed, st; + if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { + st = strides[new_index]->value; + } else { + st = 1; + } + if (new_index < begin.size() && begin[new_index].defined()) { + bg = begin[new_index]->value; + } else { + bg = 0; + } + if (new_index < end.size() && end[new_index].defined()) { + ed = end[new_index]->value; + } else { + ed = shape[new_index].as()->value; + } - new_begin.push_back(IntImm(begin[0]->dtype, bg)); - new_end.push_back(IntImm(end[0]->dtype, ed)); - new_strides.push_back(IntImm(strides[0]->dtype, st)); + new_begin.push_back(IntImm(begin[0]->dtype, bg)); + new_end.push_back(IntImm(end[0]->dtype, ed)); + new_strides.push_back(IntImm(strides[0]->dtype, st)); + } } + params->begin = new_begin; params->end = new_end; params->strides = new_strides; layout = new_layout; } } else { - for (size_t i = 0; i < begin.size(); i++) { - const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return {{Layout::Undef()}, {Layout::Undef()}}; - } - auto factor = new_layout.FactorOf(axis); - if (factor == -1) { - new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); - new_end.push_back(IntImm(end[i]->dtype, end[i])); - } else { - if (strides.defined() && i < strides.size()) { - auto stride = strides[i]; - // arbitrary stride is not supported - if (stride.defined() && stride->value != 1) { + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_axes.push_back(new_idx); + + const LayoutAxis& axis = layout[old_idx]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + + auto factor = new_layout.FactorOf(axis); + + if (factor == -1) { + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + } else { + int64_t bg = begin[i]; + int64_t ed = end[i]; + if (bg % factor || ed % factor) { + // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - int64_t bg = begin[i].defined() ? begin[i]->value : 0; - int64_t ed; - if (!end[i].defined()) { - ed = shape[i].as()->value; - } else if (params->slice_mode == "size") { - if (end[i]->value < 0) { + } + params->axes = new_axes; + + } else { + for (size_t i = 0; i < begin.size(); i++) { + const LayoutAxis& axis = layout[i]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + auto factor = new_layout.FactorOf(axis); + if (factor == -1) { + new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); + new_end.push_back(IntImm(end[i]->dtype, end[i])); + } else { + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; + // arbitrary stride is not supported + if (stride.defined() && stride->value != 1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } + } + int64_t bg = begin[i].defined() ? begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } } else { - ed = bg + end[i]->value; + ed = end[i]->value; } - } else { - ed = end[i]->value; - } - if (bg % factor || ed % factor) { - // transform to original layout - return {{Layout::Undef()}, {Layout::Undef()}}; + if (bg % factor || ed % factor) { + // transform to original layout + return {{Layout::Undef()}, {Layout::Undef()}}; + } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); - new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } } @@ -2683,63 +2676,27 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - Array begin, end, strides; - Array begin_expr, end_expr, strides_expr; - begin = param->begin.value(); - end = param->end.value(); - strides = param->strides.value(); - if (IsDynamic(out_type)) { - auto input = inputs[0]; - size_t src_tensor_dim = input->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? strides[i]->value : 1))); - } - return Array{te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})}; - } else { - for (size_t i = 0; i < begin.size(); ++i) { - begin_expr.push_back(begin[i]); - } - for (size_t i = 0; i < end.size(); ++i) { - end_expr.push_back(end[i]); - } - for (size_t i = 0; i < strides.size(); ++i) { - strides_expr.push_back(strides[i]); - } + ICHECK(param->begin && param->end && param->strides); + Array begin = param->begin.value(); + Array end = param->end.value(); + Array strides = param->strides.value(); + if (param->axes) { + auto axes = param->axes.value(); + return Array{ + topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; } - return Array{ - topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)}; + return Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; } // Positional relay function to create StridedSlice operator used by frontend FFI. Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode) { + String slice_mode, Optional> axes) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); attrs->slice_mode = slice_mode; + attrs->axes = std::move(axes); static const Op& op = Op::Get("strided_slice"); return Call(op, {data}, Attrs(attrs), {}); } @@ -3057,16 +3014,21 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& ICHECK(param != nullptr); Array src_shape = inputs[0]->shape; Array target_shape = inputs[1]->shape; - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < src_shape.size(); ++i) { begin_idx.push_back(0); strides.push_back(1); } - end_idx = Array(src_shape); + for (auto s : src_shape) { + ICHECK(s->IsInstance()) << "slice_like does not support dynamic input shape"; + end_idx.push_back(topi::GetConstInt(s)); + } if (!param->axes.defined()) { for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { - end_idx.Set(i, target_shape[i]); + ICHECK(target_shape[i]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(i, topi::GetConstInt(target_shape[i])); ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) << "End index of axis " << i << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " @@ -3078,7 +3040,9 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& if (axis < 0) { axis = static_cast(src_shape.size()) + axis; } - end_idx.Set(axis, target_shape[axis]); + ICHECK(target_shape[axis]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(axis, topi::GetConstInt(target_shape[axis])); ICHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) << "End index of axis " << axis << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " @@ -3373,10 +3337,12 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } -Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, + Optional index_rank = NullValue()) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; + attrs->index_rank = index_rank; return Call(op, {data, indices}, Attrs(attrs)); } @@ -3974,10 +3940,11 @@ bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, } const int ndim = static_cast(data->shape.size()); ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor"; - ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype"; + std::vector fields; fields.push_back(TensorType(data->shape, data->dtype)); // unique fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices + fields.push_back(TensorType(data->shape, DataType::Int(32))); // inverse_indices fields.push_back(TensorType(Array{1}, DataType::Int(32))); // num_unique const auto* param = attrs.as(); if (param->return_counts) { diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 53cd71745d5b..8c33c1648cf3 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -152,24 +152,39 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs IndexExpr num_classes = scores_shape[1]; IndexExpr num_boxes = boxes_shape[1]; - IndexExpr num_total_boxes = Any(); - if (!batch.as() && !num_boxes.as()) { - num_total_boxes = batch * num_classes * num_boxes; - } + const auto* param = attrs.as(); + CHECK(param); - // assign output type std::vector fields; - std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(64))); + if (param->output_format == "onnx") { + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } + std::vector oshape{num_total_boxes, 3}; + std::vector counts_shape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } else { + IndexExpr num_total_boxes_per_batch = Any(); + if (!num_boxes.as()) { + num_total_boxes_per_batch = num_classes * num_boxes; + } + std::vector indices_shape{batch, num_total_boxes_per_batch, 2}; + std::vector scores_shape{batch, num_total_boxes_per_batch}; + std::vector counts_shape{batch}; + fields.push_back(TensorType(indices_shape, DataType::Int(64))); + fields.push_back(TensorType(scores_shape, DataType::Float(32))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } reporter->Assign(types[5], TupleType(Array(fields))); return true; } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold) { + Expr score_threshold, std::string output_format = "onnx") { auto attrs = make_object(); + attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, Attrs(attrs), {}); diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index b0fe9356a758..7af5c2ac1c33 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -53,15 +53,20 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; - axis = (axis < 0) ? data->shape.size() + axis : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << dequantize_attrs->axis << " is out of range"; + auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + PrimExpr axis_shape; + if (rank > 0) { + axis_shape = data->shape[axis]; + } else { + axis_shape = Integer(1); + } // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point - + AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale + AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point const Array oshape = data->shape; // assign output type, output will always be float 32. reporter->Assign(types[3], TensorType(oshape, DataType::Float(32))); diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index e365dca3860f..b12e25a425b6 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include "pass_utils.h" diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc new file mode 100644 index 000000000000..f883b4113656 --- /dev/null +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/quantize_fake_quantization.cc + * \brief A pass for taking fake quantized graphs and converting them + * to actual integer operations. + */ + +#include +#include +#include + +/* Description of FakeQuantizationToInteger + * + * The purpose of this pass is to find regions of the graph that follow + * the general pattern: + * + * x w + * | | + * dq dq + * \ / + * op1 + * | + * op2 + * | + * q + * + * and convert them into subgraphs with actual integer operations on x and w + * + * The pass does this via a multi-pass approach: + * + * The main pass is a MixedModeMutator that traverses the full graph searching for + * quantize operations + * + * The second pass is an ExprVisitor that recursively searches for subgraphs leading to the + * quantize for subtraphs bounded by dequantize operations. This pass extracts the affine + * types of the inputs for later processing, where affine denotes the transformation + * x_real = (x_affine - zero_point) * scale + * + * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs + * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite + * the ops based on the affine types of their inputs and then return the affine types of the + * new rewriten ops to pass that information down the stack during rewrite. + * + * After the second and third passes run, the first pass replaces the quantize with the + * rewritten subgraph and the processing continues + */ + +namespace tvm { +namespace relay { + +/*! + * \brief AffineType representation + * \sa AffineType + */ +class AffineTypeNode : public Object { + public: + /*! \brief The scale of this type */ + Expr scale; + /*! \brief The zero point of this type */ + Expr zero_point; + /*! \brief The data type of this type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("scale", &scale); + v->Visit("zero_point", &zero_point); + v->Visit("dtype", &dtype); + } + + bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(scale, other->scale) && equal(zero_point, other->zero_point) && + equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(scale); + hash_reduce(zero_point); + hash_reduce(dtype); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const char* _type_key = "AffineTypeNode"; + TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); +}; + +/*! + * \brief Managed reference to AffineTypes. + * \sa AffineTypeNode + */ +class AffineType : public ObjectRef { + public: + TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) { + ObjectPtr n = make_object(); + n->scale = std::move(scale); + n->zero_point = std::move(zero_point); + n->dtype = std::move(dtype); + data_ = std::move(n); + } + TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); +}; + +TVM_REGISTER_NODE_TYPE(AffineTypeNode); + +using ExprSet = std::unordered_set; +using ExprMap = std::unordered_map; +using AffineTypeMap = Map; + +using FTVMFakeQuantizationToInteger = + runtime::TypedPackedFunc(const Expr& expr, const AffineTypeMap& map)>; + +class SubgraphExtractor : public ExprVisitor { + public: + const ExprSet GetSubgraph(const Expr& expr) { + VisitExpr(expr); + ExprSet subgraph; + if (is_fake_quantized_) { + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node->op != quantize_op_) { + subgraph.insert(Downcast(GetRef(kv.first))); + } + } + } + } + return subgraph; + } + const AffineTypeMap GetAffineTypes() { return affine_types_; } + void VisitExpr(const Expr& expr) override { + if (expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + is_fake_quantized_ = false; + } else { + ExprVisitor::VisitExpr(expr); + } + } + + protected: + void VisitExpr_(const CallNode* call_node) override { + if (call_node->op == quantize_op_) { + // Only look at arg0 for quantize + VisitExpr(call_node->args[0]); + // Collect type of quantize ops + affine_types_.Set(GetRef(call_node), + AffineType(call_node->args[1], call_node->args[2], + call_node->checked_type().as()->dtype)); + } else if (call_node->op == dequantize_op_) { + // Collect type of dequantize ops + affine_types_.Set(GetRef(call_node), + AffineType(call_node->args[1], call_node->args[2], + call_node->args[0]->checked_type().as()->dtype)); + } else { + // run normally on everything else. + ExprVisitor::VisitExpr_(call_node); + } + } + + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + bool is_fake_quantized_ = true; + AffineTypeMap affine_types_; +}; + +class SubgraphMutator : public ExprMutator { + public: + SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types) + : subgraph_(subgraph), affine_types_(affine_types) {} + + Expr MutateSubgraph(const Expr& expr) { + if (subgraph_.size() == 0) { + return expr; + } + const CallNode* quantize_node = expr.as(); + ICHECK(quantize_node); + ICHECK(quantize_node->op == quantize_op_); + out_type_ = affine_types_[expr]; + static auto fqfq = + Op::GetAttrMap("FTVMFakeQuantizationToInteger"); + for (auto node : subgraph_) { + if (!fqfq.count(Downcast(node.as()->op))) { + // Only modify the subgraph if we have translation + // rules for every op + return expr; + } + } + return Mutate(expr); + } + + protected: + Expr VisitExpr_(const CallNode* call_node) { + Expr out; + + static auto fqfq = + Op::GetAttrMap("FTVMFakeQuantizationToInteger"); + Op op = Downcast(call_node->op); + if (fqfq.count(op)) { + Expr expr; + if (op == dequantize_op_) { + expr = GetRef(call_node); + } else { + expr = ExprMutator::VisitExpr_(call_node); + // Set the current op to the output type, useful if we can't deduce output parameters + // from input parameters + affine_types_.Set(expr, out_type_); + } + // Call the rewrite + Array vals = fqfq[op](expr, affine_types_); + // Save teh outputs of the rewrite + ICHECK(vals.size() == 4) + << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " + << AsText(op, false); + out = Downcast(vals[0]); + affine_types_.Set(out, AffineType(Downcast(vals[1]), Downcast(vals[2]), + DataType(String2DLDataType(Downcast(vals[3]))))); + } else { + ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " + << AsText(GetRef(call_node), false); + } + return out; + } + ExprSet subgraph_; + AffineTypeMap affine_types_; + AffineType out_type_; + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); +}; + +class FakeQuantizationRewriter : public MixedModeMutator { + protected: + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call_node = post.as()) { + if (call_node->op == quantize_op_) { + SubgraphExtractor extractor; + ExprSet subgraph = extractor.GetSubgraph(GetRef(pre)); + AffineTypeMap affine_types = extractor.GetAffineTypes(); + + ExprSet post_subgraph; + AffineTypeMap post_affine_types; + + for (auto kv : affine_types) { + if (pre == kv.first.as()) { + // we havent memoized the current op yet + post_affine_types.Set(post, kv.second); + } else { + post_affine_types.Set(memo_.at(kv.first), kv.second); + } + } + for (auto expr : subgraph) { + post_subgraph.insert(memo_[expr]); + } + Expr out = SubgraphMutator(post_subgraph, post_affine_types).MutateSubgraph(post); + return out; + } + } + return post; + } + const Op quantize_op_ = Op::Get("qnn.quantize"); +}; + +Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod) { + return FakeQuantizationRewriter().Mutate(expr); +} + +namespace transform { + +Pass FakeQuantizationToInteger() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FakeQuantizationToInteger(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger") + .set_body_typed(FakeQuantizationToInteger); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 91fb4cfa8973..f6da52ebe30c 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -34,7 +34,11 @@ namespace relay { class FastMathMutator : public ExprRewriter { public: - FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {} + FastMathMutator() + : exp_op_(Op::Get("exp")), + erf_op_(Op::Get("erf")), + tanh_op_(Op::Get("tanh")), + softmax_op_(Op::Get("nn.softmax")) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == exp_op_) { @@ -43,6 +47,8 @@ class FastMathMutator : public ExprRewriter { return FastErf(post.as()->args[0]); } else if (pre->op == tanh_op_) { return FastTanh(post.as()->args[0]); + } else if (pre->op == softmax_op_) { + return FastSoftmax(post.as()->args[0], post.as()->attrs); } return post; } @@ -54,6 +60,7 @@ class FastMathMutator : public ExprRewriter { const Op& exp_op_; const Op& erf_op_; const Op& tanh_op_; + const Op& softmax_op_; }; Expr FastMath(const Expr& e) { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index fe5f547449ad..57603035b848 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/transforms/label_ops.cc b/src/relay/transforms/label_ops.cc index e0d3892a8d01..861342b03a76 100644 --- a/src/relay/transforms/label_ops.cc +++ b/src/relay/transforms/label_ops.cc @@ -19,7 +19,6 @@ #include #include #include -#include namespace tvm { namespace relay { diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 94891c3c98ea..1dda0d5cf429 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -35,7 +35,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 50a695bf1d84..920ac153b63d 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -498,6 +498,11 @@ inline Expr FastTanh(Expr e) { return Call(op, {e}); } +inline Expr FastSoftmax(Expr e, tvm::Attrs attr) { + static const Op& op = Op::Get("nn.fast_softmax"); + return Call(op, {e}, attr); +} + inline Expr Log(Expr e) { static const Op& op = Op::Get("log"); return Call(op, {e}); diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index 7e587664b4dc..846bc08e3054 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -178,7 +178,7 @@ Expr L2NormToInferUnpack(const Attrs attrs, Expr data) { return Divide(data, sqrt); } -class InferenceSimplifier : public ExprMutator { +class InferenceSimplifier : public MixedModeMutator { public: InferenceSimplifier() : batch_norm_op_(Op::Get("nn.batch_norm")), @@ -188,8 +188,7 @@ class InferenceSimplifier : public ExprMutator { group_norm_op_(Op::Get("nn.group_norm")), l2_norm_op_(Op::Get("nn.l2_normalize")) {} - Expr VisitExpr_(const TupleGetItemNode* n) final { - Expr new_e = ExprMutator::VisitExpr_(n); + Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final { const auto* new_n = new_e.as(); if (new_n->index != 0) { return new_e; @@ -205,8 +204,7 @@ class InferenceSimplifier : public ExprMutator { return new_e; } - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); + Expr Rewrite_(const CallNode* n, const Expr& new_n) { if (n->op == batch_norm_op_) { ty_map_[new_n.as()->args[0]] = n->args[0]->checked_type(); } else if (n->op == layer_norm_op_) { diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 3d9b1481f6e6..159404be5351 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -21,7 +21,12 @@ * \file src/runtime/container.cc * \brief Implementations of common containers. */ -#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -29,6 +34,42 @@ namespace tvm { namespace runtime { +// Array +TVM_REGISTER_OBJECT_TYPE(ArrayNode); + +TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); + } + } + *ret = Array(data); +}); + +TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); +}); + +TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + *ret = static_cast(static_cast(ptr)->size()); +}); + +// ADT + +TVM_REGISTER_OBJECT_TYPE(ADTObj); + TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); @@ -67,6 +108,9 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ADT(tag, fields); }); +// String +TVM_REGISTER_OBJECT_TYPE(StringObj); + TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { return String(std::move(str)); }); @@ -75,41 +119,7 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { return std::string(str); }); -TVM_REGISTER_OBJECT_TYPE(ADTObj); -TVM_REGISTER_OBJECT_TYPE(StringObj); -TVM_REGISTER_OBJECT_TYPE(ClosureObj); - -TVM_REGISTER_OBJECT_TYPE(ArrayNode); - -TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } - } - *ret = Array(data); -}); - -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; - *ret = n->at(i); -}); - -TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - *ret = static_cast(static_cast(ptr)->size()); -}); - +// Map TVM_REGISTER_OBJECT_TYPE(MapNode); TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -174,5 +184,27 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; #endif +// Closure +TVM_REGISTER_OBJECT_TYPE(ClosureObj); + +// ShapeTuple +TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj); + +TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector shape; + for (int i = 0; i < args.size(); i++) { + shape.push_back(args[i]); + } + *rv = ShapeTuple(shape); +}); + +TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) { + return static_cast(shape.size()); +}); + +TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) { + ICHECK_LT(idx, shape.size()); + return shape[idx]; +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 6562d1bfc62d..5bbc536afaca 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -381,9 +381,9 @@ class ACLRuntime : public JSONRuntimeBase { void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) { std::vector padding = node.GetAttr>("padding"); std::vector strides = node.GetAttr>("strides"); + std::vector dilation = node.GetAttr>("dilation"); bool ceil_mode = std::stoi(node.GetAttr>("ceil_mode")[0]); arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides, ceil_mode); - auto attr_pool_size = node.GetAttr>("pool_size"); int pool_size_h = std::stoi(attr_pool_size[0]); int pool_size_w = std::stoi(attr_pool_size[1]); @@ -408,6 +408,8 @@ class ACLRuntime : public JSONRuntimeBase { LOG(FATAL) << "Pooling type not supported"; } + ICHECK(dilation.size() == 2 && dilation[0] == "1" && dilation[1] == "1") + << "Dilation other than (1, 1) not supported"; arm_compute::PoolingLayerInfo pool_info = arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w), arm_compute::DataLayout::NHWC, pad_stride_info, exclude_pad); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 9ccfa5183cd6..4724b14bffa1 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -181,28 +181,48 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); + // C can never be transposed. ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; + ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } + DType* A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); DType* B_data = reinterpret_cast(static_cast(B->data) + @@ -210,9 +230,9 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DType* C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), - ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, - ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), - static_cast(beta), C_data, C_size, ColumnStride3D(C)); + ColumnCount3D(A, transa), static_cast(alpha), B_data, + B_stride, ColumnStride3D(B), A_data, A_stride, ColumnStride3D(A), + static_cast(beta), C_data, C_stride, ColumnStride3D(C)); } } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 9af1602cf3c0..015d68aec819 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -39,7 +39,7 @@ inline void CUBLASTryEnableTensorCore(cublasHandle_t hdl) { // TensorCores are only supported in cublas 9.0 or higher int version; CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version)); - if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); + if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_DEFAULT_MATH)); } struct CublasHgemmOp { @@ -275,9 +275,8 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); @@ -299,9 +298,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype); cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype); @@ -325,8 +338,9 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( hdl, CUBLASBooleanToTranspose(transb), CUBLASBooleanToTranspose(transa), ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, - cuda_in_type, ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); + cuda_in_type, ColumnStride3D(B), B_stride, A_data, cuda_in_type, ColumnStride3D(A), A_stride, + beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_stride, batch_size, cuda_out_type, + algo)); } // matrix multiplication for row major diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index a7a57ff422e3..341062f1c492 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -31,6 +31,7 @@ #include #include "../tflite/tflite_runtime.h" +#include "edgetpu.h" namespace tvm { namespace runtime { @@ -43,6 +44,14 @@ namespace runtime { */ class EdgeTPURuntime : public TFLiteRuntime { public: + /*! + * \brief Destructor of EdgeTPURuntime. + * + * NOTE: tflite::Interpreter member should be destruct before the EdgeTpuContext member + * destruction. If the order is reverse, occurs SEGV in the destructor of tflite::Interpreter. + */ + ~EdgeTPURuntime() { interpreter_.reset(); } + /*! * \return The type key of the executor. */ diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 55f16635b9e6..1735d8569215 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -25,7 +25,6 @@ #ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ -#include #include #include diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index b235d63dbc58..8732b700a218 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -21,7 +21,6 @@ * \file onnx_module.cc * \brief ONNX Module without runtime support */ -#include #include #include diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index b8d6f6cd9ff0..d8182b0e8378 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -178,15 +178,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); - // Allocate I/O buffers on GPU for TVM inputs which are on a different context. - std::vector device_buffers(engine->getNbBindings()); - for (size_t i = 0; i < network_input_names_.size(); ++i) { - AllocateDeviceBuffer(engine, network_input_names_[i], &device_buffers); - } - for (size_t i = 0; i < network_output_names_.size(); ++i) { - AllocateDeviceBuffer(engine, network_output_names_[i], &device_buffers); - } - return {engine, context, network_input_names_, network_output_names_, device_buffers}; + return {engine, context, network_input_names_, network_output_names_}; } nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, @@ -245,19 +237,6 @@ void TensorRTBuilder::CleanUp() { } } -void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name, - std::vector* device_buffers) { - const uint32_t entry_id = entry_id_map_[name]; - if (data_entry_[entry_id]->device.device_type != kDLCUDA) { - const int binding_index = engine->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); - std::vector shape(data_entry_[entry_id]->shape, - data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); - device_buffers->at(binding_index) = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); - } -} - } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h index 4926a4d02685..0b1c3997ec57 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -52,8 +52,6 @@ struct TensorRTEngineAndContext { nvinfer1::IExecutionContext* context; std::vector inputs; std::vector outputs; - /*! \brief GPU buffers for inputs and outputs. */ - std::vector device_buffers; }; /*! @@ -123,12 +121,6 @@ class TensorRTBuilder { /*! \brief Clean up resources used to create engine. */ void CleanUp(); - /*! \brief Allocate a GPU buffer for input or output DLTensor, only if the context is not GPU - * already. Inputs that are already on the GPU can be passed directly to TensorRT and will not - * need a buffer. */ - void AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name, - std::vector* device_buffers); - /*! \brief Maps a node to its outputs. */ std::unordered_map> node_output_map_; diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index e96359481ddb..6358e59ce3bc 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -64,7 +64,9 @@ class TensorRTRuntime : public JSONRuntimeBase { const Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), use_implicit_batch_(true), - max_workspace_size_(size_t(1) << 30) {} + max_workspace_size_(size_t(1) << 30), + max_batch_size_(-1), + multi_engine_mode_(false) {} /*! * \brief The type key of the module. @@ -85,6 +87,7 @@ class TensorRTRuntime : public JSONRuntimeBase { LoadGlobalAttributes(); if (GetCachedEnginesFromDisk()) return; SetupConstants(consts); + multi_engine_mode_ = dmlc::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false); } void LoadGlobalAttributes() { @@ -110,23 +113,25 @@ class TensorRTRuntime : public JSONRuntimeBase { #ifdef TVM_GRAPH_EXECUTOR_TENSORRT /*! \brief Destroy engines and contexts. */ - ~TensorRTRuntime() { + void DestroyEngines() { for (auto& it : trt_engine_cache_) { it.second.context->destroy(); it.second.engine->destroy(); } + trt_engine_cache_.clear(); } + ~TensorRTRuntime() { DestroyEngines(); } + /*! \brief Run inference using built engine. */ void Run() override { - BuildEngine(); - batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; - if (batch_size_ == 0) return; - auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_)); + auto& engine_and_context = GetOrBuildEngine(); + int batch_size = GetBatchSize(); + if (batch_size == 0) return; auto engine = engine_and_context.engine; auto context = engine_and_context.context; - auto& device_buffers = engine_and_context.device_buffers; std::vector bindings(engine->getNbBindings(), nullptr); + // Setup input bindings. for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { @@ -138,13 +143,14 @@ class TensorRTRuntime : public JSONRuntimeBase { if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { - device_buffers[binding_index].CopyFrom(data_entry_[eid]); - bindings[binding_index] = device_buffers[binding_index]->data; + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + device_buffer.CopyFrom(data_entry_[eid]); + bindings[binding_index] = device_buffer->data; } } } } - + // Setup output bindings. for (size_t i = 0; i < outputs_.size(); ++i) { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; @@ -153,18 +159,19 @@ class TensorRTRuntime : public JSONRuntimeBase { if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { - bindings[binding_index] = device_buffers[binding_index]->data; + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + bindings[binding_index] = device_buffer->data; } } #if TRT_VERSION_GE(6, 0, 1) if (use_implicit_batch_) { - ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; } else { ICHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; } #else - ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; #endif // Copy outputs from GPU buffers if needed. @@ -174,25 +181,58 @@ class TensorRTRuntime : public JSONRuntimeBase { int binding_index = engine->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type != kDLCUDA) { - device_buffers[binding_index].CopyTo(const_cast(data_entry_[eid])); + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + device_buffer.CopyTo(const_cast(data_entry_[eid])); } } } private: + /*! \brief Get batch size for engine from the runtime input shapes. */ + int GetBatchSize() { + return data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0]; + } + + /*! \brief Find an engine in the cache which we can reuse depending on the mode. If no compatible + * engine exists, return false to indicate that a new one should be built. */ + bool FindCompatibleEngine(int batch_size, int* compatible_engine_batch_size) { + if (multi_engine_mode_) { + // Exact match is required for multi engine mode. + if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size))) { + *compatible_engine_batch_size = batch_size; + return true; + } + return false; + } + // Check for engine with compatible max_batch_size. + if (batch_size <= max_batch_size_) { + *compatible_engine_batch_size = max_batch_size_; + return true; + } + return false; + } + /*! - * \brief Build TensorRT engine from JSON representation and cache it. If engine is already built, - * do nothing. + * \brief Build TensorRT engine from JSON representation and cache it. If compatible engine is + * already built, do nothing. */ - void BuildEngine() { - batch_size_ = - data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0]; - if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return; + TensorRTEngineAndContext& GetOrBuildEngine() { + int batch_size = GetBatchSize(); + int compatible_engine_batch_size = -1; + if (FindCompatibleEngine(batch_size, &compatible_engine_batch_size)) { + // A compatible engine already exists. + return trt_engine_cache_.at(std::make_pair(symbol_name_, compatible_engine_batch_size)); + } + // For single engine mode, remove previous engine and update max_batch_size. + if (!multi_engine_mode_) { + DestroyEngines(); + max_batch_size_ = batch_size; + } DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_ - << " with batch size " << batch_size_; + << " with batch size " << batch_size; const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_, - use_fp16, batch_size_); + use_fp16, batch_size); // Add inputs and constants. for (size_t i = 0; i < input_nodes_.size(); ++i) { @@ -221,10 +261,11 @@ class TensorRTRuntime : public JSONRuntimeBase { } // Build engine. - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine(); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = builder.BuildEngine(); DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_ - << " with batch size " << batch_size_; + << " with batch size " << batch_size; CacheEngineToDisk(); + return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size)); } /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for @@ -268,7 +309,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * directory so it can be loaded later. */ void CacheEngineToDisk() { - batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; + int batch_size = GetBatchSize(); std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); if (cache_dir.empty()) return; std::string key = GetSubgraphKey(); @@ -276,7 +317,7 @@ class TensorRTRuntime : public JSONRuntimeBase { DLOG(INFO) << "Caching TensorRT engine to " << path; // Serialize engine to disk nvinfer1::IHostMemory* serialized_engine = - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize(); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].engine->serialize(); SaveBinaryToFile(path, std::string(static_cast(serialized_engine->data()), serialized_engine->size())); serialized_engine->destroy(); @@ -285,9 +326,9 @@ class TensorRTRuntime : public JSONRuntimeBase { dmlc::JSONWriter writer(&os); writer.BeginObject(); writer.WriteObjectKeyValue("inputs", - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs); - writer.WriteObjectKeyValue( - "outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].inputs); + writer.WriteObjectKeyValue("outputs", + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].outputs); writer.EndObject(); std::string meta_path = cache_dir + "/" + key + ".meta"; SaveBinaryToFile(meta_path, os.str()); @@ -300,29 +341,41 @@ class TensorRTRuntime : public JSONRuntimeBase { return symbol_name_ + (dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false) ? "_fp16" : "_fp32"); } - /*! \brief Get the batch size when in implicit_batch mode. */ - int GetBatchSize() { - if (!use_implicit_batch_) return -1; - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto nid = input_nodes_[i]; - if (nodes_[nid].GetOpType() == "input") { - // Get batch size from first input. - return nodes_[nid].GetOpShape()[0][0]; + /*! \brief Retreive a GPU buffer for input or output or allocate if needed. */ + NDArray GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { + std::vector shape(data_entry_[entry_id]->shape, + data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); + if (device_buffers_.count(binding_index)) { + // Buffer is already initialized. + if (shape[0] > device_buffers_[binding_index]->shape[0]) { + // Buffer is too small. Need to allocate bigger buffer. + device_buffers_[binding_index] = + runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { + // Buffer is too large. Create view. + return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); } + } else { + // Buffer not initialized yet. + device_buffers_[binding_index] = + runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } - return -1; + return device_buffers_.at(binding_index); } - /*! \brief Map of function name to TRT engine if built already. */ + /*! \brief Map of function name and max batch size to TRT engine if built already. */ std::unordered_map, TensorRTEngineAndContext, PairHash> trt_engine_cache_; + /*! \brief Map of inding index to GPU buffers for inputs and outputs. Only used when target device + * is not "cuda". Since TensorRT execution can only read data from GPU, we need to copy data from + * the runtime device to these buffers first. These will be allocated for the highest batch size + * used by all engines. */ + std::unordered_map device_buffers_; + /*! \brief TensorRT logger. */ TensorRTLogger logger_; - /*! \brief Batch size that the engine is optimized for. */ - int batch_size_; - #else void Run() override { LOG(FATAL) << "TensorRT runtime is not enabled. " @@ -342,6 +395,17 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_implicit_batch_; size_t max_workspace_size_; + + /*! \brief Highest batch size that an engine has been built for, used in single-engine mode only + * (multi_engine_mode=false). */ + int max_batch_size_; + + /*! \brief The strategy to use for dynamic batching. With multi_engine_mode=true, a new TensorRT + * engine is created for each unique batch size encountered. With multi_engine_mode=false, only + * one TensorRT engine is alive at any given time. It is replaced if a higher batch size is + * encountered. Multi-engine mode should give better performance, at a cost of higher memory usage + * and more time spent building engines. */ + bool multi_engine_mode_; }; runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 718d10d5df70..4e7f158bb04f 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -24,7 +24,8 @@ #ifndef TVM_RUNTIME_FILE_UTILS_H_ #define TVM_RUNTIME_FILE_UTILS_H_ -#include +#include +#include #include #include diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 5736462a648d..1ea01b19e8aa 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -20,7 +20,7 @@ /*! * \file graph_executor_debug.cc */ -#include +#include #include #include #include diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 584aafe3410b..1084b4ee3ec4 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -22,7 +22,8 @@ */ #include "graph_executor.h" -#include +#include +#include #include #include #include @@ -266,10 +267,10 @@ void GraphExecutor::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) { std::vector shape_vec{template_tensor->shape, template_tensor->shape + template_tensor->ndim}; - std::unique_ptr container{new NDArray::Container( - static_cast(opaque_handle), shape_vec, template_tensor->dtype, dev)}; + auto* container = new NDArray::Container(static_cast(opaque_handle), shape_vec, + template_tensor->dtype, dev); container->SetDeleter(GraphExecutor::LinkedNDArrayDeleter); - *rv = NDArray(GetObjectPtr(container.release())); + *rv = NDArray(GetObjectPtr(container)); } void GraphExecutor::SetupStorage() { diff --git a/src/runtime/graph_executor/graph_executor_factory.cc b/src/runtime/graph_executor/graph_executor_factory.cc index 8ea21cabf519..a13fbd860d43 100644 --- a/src/runtime/graph_executor/graph_executor_factory.cc +++ b/src/runtime/graph_executor/graph_executor_factory.cc @@ -24,7 +24,7 @@ #include "./graph_executor_factory.h" -#include +#include #include #include diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc index 4a1d89ce1a1f..7cb986bba62c 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/metadata_module.cc @@ -27,7 +27,8 @@ * code and metadata significantly reduces the efforts for handling external * codegen and runtimes. */ -#include +#include +#include #include #include #include diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 9ebe04efbe4c..7d2ef0c9367b 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -42,9 +42,65 @@ #include "../workspace_pool.h" +/* Macro for convenience in using AutoReleasePoolWrapper. + * With this macro we can add AutoReleasePoolWrapper to our ObjC code in more + * native way. + * + * For example, this is ObjC code with autoreleasepool: + * @autoreleasepool { + * // Some code + * } + * + * To avoid possible memory leaks when an exception will be generated, we + * should update this code: + * AUTORELEASEPOOL { // Replace @autoreleasepool -> AUTORELEASEPOOL + * // Some code + * }; // Add semicolon after close bracket + * + * In macro AUTORELEASEPOOL we get the instance of AutoReleasePoolWrapper and + * put a lambda function with code from autoreleasepool to the insertion + * operator of AutoReleasePoolWrapper class. + * + * Note: If you want to return a value from the autoreleasepool, you should + * declare the variable with result before AUTORELEASEPOOL macro. This variable + * will be captured by reference and you can use it in the code in autorelease + * pool. But you should write return statement after AUTORELEASEPOOL macro. + */ +#define AUTORELEASEPOOL tvm::runtime::metal::AutoReleasePoolWrapper::GetInstance() << [&]() + namespace tvm { namespace runtime { namespace metal { +/*! + * \brief Wrapper on autoreleasepool with exception handling + * + * \note In case when the exception was thrown from the autoreleasepool, the + * allocated resources won't be released in proper way. So, we handle exception + * in autoreleasepool and after the autoreleasepool we rethrow this exception. + */ +class AutoReleasePoolWrapper { + public: + static AutoReleasePoolWrapper& GetInstance(); + template + void operator<<(const T& f) { + std::exception_ptr eptr; + @autoreleasepool { + try { + f(); + } catch (...) { + eptr = std::current_exception(); + } + } + if (eptr) std::rethrow_exception(eptr); + } + + private: + AutoReleasePoolWrapper() = default; + ~AutoReleasePoolWrapper() = default; + AutoReleasePoolWrapper(const AutoReleasePoolWrapper&) = delete; + AutoReleasePoolWrapper& operator=(const AutoReleasePoolWrapper&) = delete; +}; + /*! * \brief Structure for error handling in queues */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 193e4647733a..1c5666dfc17f 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -29,17 +29,20 @@ namespace runtime { namespace metal { +AutoReleasePoolWrapper& AutoReleasePoolWrapper::GetInstance() { + static AutoReleasePoolWrapper instance; + return instance; +} + MetalWorkspace* MetalWorkspace::Global() { - @autoreleasepool { - // NOTE: explicitly use new to avoid exit-time destruction of global state - // Global state will be recycled by OS as the process exits. - static MetalWorkspace* inst = new MetalWorkspace(); - return inst; - } + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static MetalWorkspace* inst = new MetalWorkspace(); + return inst; } void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { - @autoreleasepool { + AUTORELEASEPOOL { this->Init(); size_t index = static_cast(dev.device_id); if (kind == kExist) { @@ -80,7 +83,7 @@ case kDriverVersion: return; } - } + }; } static const char* kDummyKernel = R"A0B0( @@ -161,7 +164,8 @@ int GetWarpSize(id dev) { void* MetalWorkspace::AllocDataSpace(Device device, size_t nbytes, size_t alignment, DLDataType type_hint) { - @autoreleasepool { + id buf; + AUTORELEASEPOOL { this->Init(); id dev = GetDevice(device); // GPU memory only @@ -173,20 +177,20 @@ int GetWarpSize(id dev) { storage_mode = MTLResourceStorageModeManaged; #endif */ - id buf = [dev newBufferWithLength:nbytes options:storage_mode]; + buf = [dev newBufferWithLength:nbytes options:storage_mode]; ICHECK(buf != nil); - return (void*)(buf); - } + }; + return (void*)(buf); } void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) { - @autoreleasepool { + AUTORELEASEPOOL { // MTLBuffer PurgeableState should be set to empty before manual // release in order to prevent memory leak [(id)ptr setPurgeableState:MTLPurgeableStateEmpty]; // release the ptr. CFRelease(ptr); - } + }; } Stream* GetStream(TVMStreamHandle stream, int device_id) { @@ -199,7 +203,7 @@ int GetWarpSize(id dev) { void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { - @autoreleasepool { + AUTORELEASEPOOL { this->Init(); Device dev = dev_from; Stream* s = GetStream(stream, dev.device_id); @@ -261,7 +265,7 @@ int GetWarpSize(id dev) { LOG(FATAL) << "Expect copy from/to Metal or between Metal" << ", from=" << from_dev_type << ", to=" << to_dev_type; } - } + }; } TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { @@ -276,7 +280,7 @@ int GetWarpSize(id dev) { } void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { - @autoreleasepool { + AUTORELEASEPOOL { Stream* s = GetStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); @@ -285,7 +289,7 @@ int GetWarpSize(id dev) { if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned!"; } - } + }; } void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index e22caa21a81e..88501880557e 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -30,6 +30,7 @@ #include "../file_utils.h" #include "../meta_data.h" #include "../pack_args.h" +#include "../source_utils.h" #include "../thread_storage_scope.h" #include "metal_common.h" @@ -43,7 +44,9 @@ public: explicit MetalModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { + parsed_kernels_ = SplitKernels(data); + } const char* type_key() const final { return "metal"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; @@ -71,6 +74,7 @@ void SaveToBinary(dmlc::Stream* stream) final { return ""; } } + // get a from primary context in device_id id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); @@ -85,37 +89,44 @@ void SaveToBinary(dmlc::Stream* stream) final { if (it != e.smap.end()) return it->second; // compile NSError* err_msg = nil; - if (e.lib == nil) { - if (fmt_ == "metal") { - MTLCompileOptions* opts = [MTLCompileOptions alloc]; - opts.languageVersion = MTLLanguageVersion2_3; - opts.fastMathEnabled = YES; - // opts = nil; - e.lib = [w->devices[device_id] - newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] - options:opts - error:&err_msg]; - [opts dealloc]; - if (e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; - } - if (err_msg != nil) { - LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; - } - } else { - // Build from library. - auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - auto data = dispatch_data_create(data_.c_str(), data_.length(), q, - ^{ - }); - e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; - if (err_msg != nil || e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; - } + id lib = nil; + std::string source; + auto kernel = parsed_kernels_.find(func_name); + // If we cannot find this kernel in parsed_kernels_, it means that all kernels going together + // without explicit separator. In this case we use data_ with all kernels. It done for backward + // compatibility. + if (kernel != parsed_kernels_.end()) + source = kernel->second; + else + source = data_; + if (fmt_ == "metal") { + MTLCompileOptions* opts = [MTLCompileOptions alloc]; + opts.languageVersion = MTLLanguageVersion2_3; + opts.fastMathEnabled = YES; + // opts = nil; + lib = + [w->devices[device_id] newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()] + options:opts + error:&err_msg]; + [opts dealloc]; + if (lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; + } + if (err_msg != nil) { + LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; + } + } else { + // Build from library. + auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); + auto data = dispatch_data_create(source.c_str(), source.length(), q, + ^{ + }); + lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; + if (err_msg != nil || lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } } - id f = - [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + id f = [lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; ICHECK(f != nil) << "cannot find function " << func_name; id state = [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg]; @@ -123,6 +134,7 @@ void SaveToBinary(dmlc::Stream* stream) final { << " for function " << func_name << [[err_msg localizedDescription] UTF8String]; [f release]; + [lib release]; // The state.threadExecutionWidth can change dynamically according // to the resource constraint in kernel, so it is not strictly hold // Turn of warp aware optimziation for now. @@ -135,13 +147,10 @@ void SaveToBinary(dmlc::Stream* stream) final { private: // device specific entry struct DeviceEntry { - // library - id lib = nil; // state cache; - std::unordered_map > smap; + std::unordered_map> smap; ~DeviceEntry() { - if (lib != nil) [lib release]; for (auto&& kv : smap) { [kv.second release]; } @@ -159,6 +168,8 @@ void SaveToBinary(dmlc::Stream* stream) final { std::vector finfo_; // internal mutex when updating the module std::mutex mutex_; + // parsed kernel data + std::unordered_map parsed_kernels_; }; // a wrapped function class to get packed func. @@ -182,7 +193,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { - @autoreleasepool { + AUTORELEASEPOOL { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; auto stream = static_cast(t->stream[device_id]); @@ -212,7 +223,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock]; [encoder endEncoding]; [cb commit]; - } + }; } private: @@ -237,27 +248,33 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons PackedFunc MetalModuleNode::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - @autoreleasepool { + PackedFunc pf; + AUTORELEASEPOOL { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); + if (it == fmap_.end()) { + pf = PackedFunc(); + return; + } const FunctionInfo& info = it->second; MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.thread_axis_tags); - return PackFuncNonBufferArg(f, info.arg_types); - } + pf = PackFuncNonBufferArg(f, info.arg_types); + }; + return pf; } Module MetalModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { - @autoreleasepool { + ObjectPtr n; + AUTORELEASEPOOL { metal::MetalWorkspace::Global()->Init(); - auto n = make_object(data, fmt, fmap, source); - return Module(n); - } + n = make_object(data, fmt, fmap, source); + }; + return Module(n); } // Load module from module. diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 15b9c0dde877..acc7fc7286d1 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -139,8 +139,6 @@ bool RuntimeEnabled(const std::string& target) { f_name = "target.build.stackvm"; } else if (target == "rpc") { f_name = "device_api.rpc"; - } else if (target == "micro_dev") { - f_name = "device_api.micro_dev"; } else if (target == "hexagon") { f_name = "device_api.hexagon"; } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 3d3466bed47c..968a4488bbcf 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -123,7 +123,7 @@ struct NDArray::Internal { } // Local create function which allocates tensor metadata // but does not allocate space for the data. - static NDArray Create(std::vector shape, DLDataType dtype, Device dev) { + static NDArray Create(ShapeTuple shape, DLDataType dtype, Device dev) { VerifyDataType(dtype); // critical zone: construct header @@ -134,7 +134,7 @@ struct NDArray::Internal { NDArray ret(GetObjectPtr(data)); // setup shape data->shape_ = std::move(shape); - data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); + data->dl_tensor.shape = const_cast(data->shape_.data()); data->dl_tensor.ndim = static_cast(data->shape_.size()); // setup dtype data->dl_tensor.dtype = dtype; @@ -172,7 +172,7 @@ struct NDArray::Internal { } }; -NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { +NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ICHECK(data_ != nullptr); ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device); @@ -190,8 +190,7 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, Device dev, - Optional mem_scope) { +NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope) { NDArray ret = Internal::Create(shape, dtype, dev); ret.get_mutable()->dl_tensor.data = DeviceAPI::Get(ret->device) @@ -207,9 +206,11 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; // update shape_ - data->shape_.resize(data->dl_tensor.ndim); - data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); - data->dl_tensor.shape = data->shape_.data(); + std::vector shape; + shape.resize(data->dl_tensor.ndim); + shape.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); + data->shape_ = ShapeTuple(shape); + data->dl_tensor.shape = const_cast(data->shape_.data()); return NDArray(GetObjectPtr(data)); } @@ -242,7 +243,7 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str DeviceAPI::Get(dev)->CopyDataFromTo(const_cast(from), to, stream); } -std::vector NDArray::Shape() const { return get_mutable()->shape_; } +ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; } runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); } @@ -274,7 +275,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; - auto ndarray = NDArray::Empty(std::vector(shape, shape + ndim), dtype, dev); + auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev); *out = NDArray::Internal::MoveToFFIHandle(ndarray); API_END(); @@ -283,7 +284,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, TVMRetValue* ret) { int64_t* shape_ptr = static_cast(static_cast(args[0])); int ndim = args[1]; - std::vector shape(shape_ptr, shape_ptr + ndim); + ShapeTuple shape(shape_ptr, shape_ptr + ndim); DataType dtype = args[2]; Device dev = args[3]; Optional mem_scope = args[4]; diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 93420feec805..c31576f6d286 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -66,6 +66,7 @@ #include "../file_utils.h" #include "../meta_data.h" #include "../pack_args.h" +#include "../texture.h" #include "../thread_storage_scope.h" #include "../workspace_pool.h" @@ -174,6 +175,29 @@ inline const char* CLGetErrorString(cl_int error) { } } +inline cl_channel_type DTypeToOpenCLChannelType(DLDataType data_type) { + DataType dtype(data_type); + if (dtype == DataType::Float(32)) { + return CL_FLOAT; + } else if (dtype == DataType::Float(16)) { + return CL_HALF_FLOAT; + } else if (dtype == DataType::Int(8)) { + return CL_SIGNED_INT8; + } else if (dtype == DataType::Int(16)) { + return CL_SIGNED_INT16; + } else if (dtype == DataType::Int(32)) { + return CL_SIGNED_INT32; + } else if (dtype == DataType::UInt(8)) { + return CL_UNSIGNED_INT8; + } else if (dtype == DataType::UInt(16)) { + return CL_UNSIGNED_INT16; + } else if (dtype == DataType::UInt(32)) { + return CL_UNSIGNED_INT32; + } + LOG(FATAL) << "data type is not supported in OpenCL runtime yet: " << dtype; + return CL_FLOAT; +} + /*! * \brief Protected OpenCL call * \param func Expression to call. @@ -243,11 +267,18 @@ class OpenCLWorkspace : public DeviceAPI { void SetDevice(Device dev) final; void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final; + void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope = NullOpt) final; void FreeDataSpace(Device dev, void* ptr) final; void StreamSync(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + // Texture (image2d_t) alloca APIs + cl_mem AllocTexture(Device dev, size_t width, size_t height, DLDataType type_hint); + void* AllocTextureWorkspace(Device dev, size_t width, size_t height, DLDataType type_hint); + void FreeTextureWorkspace(Device dev, void* data); + /*! * \brief Get the thread local ThreadEntry */ @@ -256,10 +287,7 @@ class OpenCLWorkspace : public DeviceAPI { // get the global workspace static OpenCLWorkspace* Global(); - protected: - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, - Device dev_from, Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) final; + void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final; }; /*! \brief Thread local workspace */ @@ -278,9 +306,11 @@ class OpenCLThreadEntry { std::vector kernel_table; /*! \brief workspace pool */ WorkspacePool pool; + /*! \brief texture pool */ + TexturePool texture_pool; // constructor OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device_api) - : pool(device_type, device_api) { + : pool(device_type, device_api), texture_pool(device_type, device_api) { device.device_id = 0; device.device_type = device_type; } @@ -289,6 +319,29 @@ class OpenCLThreadEntry { // get the global workspace static OpenCLThreadEntry* ThreadLocal(); }; + +/*! \brief OpenCL runtime buffer structure with tracked memory layout */ +struct BufferDescriptor { + enum class MemoryLayout { + /*! \brief One dimensional buffer in row-major layout*/ + kBuffer1D, + /*! \brief Two dimensional texture w/ width = axis[-1] + * e.g. image2d[height=NCH, width=W] + */ + kImage2DActivation, + /*! \brief Two dimensional texture w/ height = axis[0] + * e.g. image2d[height=O, width=IHW] + */ + kImage2DWeight, + }; + BufferDescriptor() = default; + explicit BufferDescriptor(Optional scope) : layout(MemoryLayoutFromScope(scope)) {} + static MemoryLayout MemoryLayoutFromScope(Optional mem_scope); + static String ScopeFromMemoryLayout(MemoryLayout mem_scope); + + cl_mem buffer{nullptr}; + MemoryLayout layout{MemoryLayout::kBuffer1D}; +}; } // namespace cl // Module to support thread-safe multi-device execution. @@ -326,14 +379,6 @@ class OpenCLModuleNode : public ModuleNode { cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e); - /* - * \brief Splits the provided serialized source file into separate - * source for each kernel primitive. - * \param source The serialized program source file (fmt: cl) - * \return Mapping from primitive name to kernel source - */ - std::unordered_map SplitKernels(std::string source) const; - private: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem. @@ -357,7 +402,6 @@ class OpenCLModuleNode : public ModuleNode { // parsed kernel data std::unordered_map parsed_kernels_; }; - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index e9f092cc6579..26eddb40a7d5 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -32,6 +32,63 @@ namespace cl { std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name); std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name); +struct ImageInfo { + size_t origin[3] = {}; + size_t region[3] = {}; + size_t row_pitch = 0; + size_t slice_pitch = 0; +}; + +/*! + * \brief Utility to apply a memory layout specific lowering convention + * to infer the physical shape from the provided DLTensor's logical shape. + * \param desc Descriptor which contains the buffer and layout tag. + * \param The DLTensor used to infer the tensors physical shape. + */ +ImageInfo GetImageInfo(const cl::BufferDescriptor* desc, const DLTensor* tensor) { + ImageInfo info{}; + ICHECK(tensor->dtype.lanes == 1) << "Image dtype has lanes: " << tensor->dtype.lanes; + + info.origin[0] = info.origin[1] = info.origin[2] = 0; + info.row_pitch = 0; + info.slice_pitch = 0; + + size_t axis = DefaultTextureLayoutSeparator( + tensor->ndim, cl::BufferDescriptor::ScopeFromMemoryLayout(desc->layout)); + auto texture_shape = ApplyTexture2DFlattening(tensor->shape, tensor->ndim, axis); + info.region[0] = texture_shape.width; + info.region[1] = texture_shape.height; + info.region[2] = 1; + return info; +} + +cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( + Optional mem_scope) { + if (!mem_scope.defined()) { + return cl::BufferDescriptor::MemoryLayout::kBuffer1D; + } else if (mem_scope.value() == "global.texture") { + return cl::BufferDescriptor::MemoryLayout::kImage2DActivation; + } else if (mem_scope.value() == "global.texture-weight") { + return cl::BufferDescriptor::MemoryLayout::kImage2DWeight; + } + LOG(FATAL) << "No memory layout defined for memory of scope: " << mem_scope.value(); + return cl::BufferDescriptor::MemoryLayout::kBuffer1D; +} + +String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { + switch (layout) { + case cl::BufferDescriptor::MemoryLayout::kBuffer1D: + return "global"; + case cl::BufferDescriptor::MemoryLayout::kImage2DActivation: + return "global.texture"; + case cl::BufferDescriptor::MemoryLayout::kImage2DWeight: + return "global.texture-weight"; + } + LOG(FATAL) << "No scope corresponding to the provided memory layout: " + << static_cast(layout); + return ""; +} + OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); } OpenCLWorkspace* OpenCLWorkspace::Global() { @@ -138,9 +195,30 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment, this->Init(); ICHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; - cl_mem mptr = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); + cl::BufferDescriptor* desc = new cl::BufferDescriptor; + desc->buffer = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); + desc->layout = cl::BufferDescriptor::MemoryLayout::kBuffer1D; OPENCL_CHECK_ERROR(err_code); - return mptr; + return desc; +} + +void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope) { + if (!mem_scope.defined() || mem_scope.value() == "global") { + return DeviceAPI::AllocDataSpace(dev, ndim, shape, dtype, mem_scope); + } + ICHECK(IsTextureStorage(std::string(mem_scope.value()))) + << "Device does not support allocate data space with " + << "specified memory scope: " << mem_scope.value(); + + ICHECK(ndim > 2) << "Shape for texture allocation must be at least rank 3; " + << "provided shape is rank " << ndim; + + cl::BufferDescriptor* desc = new cl::BufferDescriptor(mem_scope); + size_t axis = DefaultTextureLayoutSeparator(ndim, mem_scope.value()); + auto texture = ApplyTexture2DFlattening(shape, ndim, axis); + desc->buffer = AllocTexture(dev, texture.width, texture.height, dtype); + return desc; } void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) { @@ -148,31 +226,87 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) { // for some OpenCL platforms. OPENCL_CALL(clFinish(this->GetQueue(dev))); - cl_mem mptr = static_cast(ptr); - OPENCL_CALL(clReleaseMemObject(mptr)); + cl::BufferDescriptor* desc = static_cast(ptr); + OPENCL_CALL(clReleaseMemObject(desc->buffer)); + delete desc; } -void OpenCLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, Device dev_from, Device dev_to, - DLDataType type_hint, TVMStreamHandle stream) { +cl_mem OpenCLWorkspace::AllocTexture(Device dev, size_t width, size_t height, + DLDataType type_hint) { this->Init(); - ICHECK(stream == nullptr); - if (IsOpenCLDevice(dev_from) && IsOpenCLDevice(dev_to)) { - OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(dev_to), - static_cast((void*)from), // NOLINT(*) - static_cast(to), from_offset, to_offset, size, 0, - nullptr, nullptr)); - } else if (IsOpenCLDevice(dev_from) && dev_to.device_type == kDLCPU) { - OPENCL_CALL(clEnqueueReadBuffer(this->GetQueue(dev_from), - static_cast((void*)from), // NOLINT(*) - CL_FALSE, from_offset, size, static_cast(to) + to_offset, - 0, nullptr, nullptr)); - OPENCL_CALL(clFinish(this->GetQueue(dev_from))); - } else if (dev_from.device_type == kDLCPU && IsOpenCLDevice(dev_to)) { - OPENCL_CALL(clEnqueueWriteBuffer(this->GetQueue(dev_to), static_cast(to), CL_FALSE, - to_offset, size, static_cast(from) + from_offset, - 0, nullptr, nullptr)); - OPENCL_CALL(clFinish(this->GetQueue(dev_to))); + ICHECK(context != nullptr) << "No OpenCL device"; + cl_int err_code; + cl_channel_type cl_type = DTypeToOpenCLChannelType(type_hint); + cl_image_format format = {CL_RGBA, cl_type}; + cl_image_desc descriptor = {CL_MEM_OBJECT_IMAGE2D, width, height, 0, 0, 0, 0, 0, 0}; + cl_mem mptr = + clCreateImage(this->context, CL_MEM_READ_WRITE, &format, &descriptor, nullptr, &err_code); + OPENCL_CHECK_ERROR(err_code); + return mptr; +} + +void* OpenCLWorkspace::AllocTextureWorkspace(Device dev, size_t width, size_t height, + DLDataType type_hint) { + return GetThreadEntry()->texture_pool.AllocTexture(dev, width, height, type_hint); +} + +void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) { + GetThreadEntry()->texture_pool.FreeTexture(dev, ptr); +} + +void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + size_t nbytes = GetDataSize(*from); + ICHECK_EQ(nbytes, GetDataSize(*to)); + ICHECK(IsContiguous(*from) && IsContiguous(*to)) + << "CopyDataFromTo only support contiguous array for now"; + + if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) { + const auto* from_desc = static_cast(from->data); + ICHECK(from_desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D) + << "Device to device copying is currently only implemented for OpenCL buffer storage"; + auto* to_desc = static_cast(to->data); + OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(to->device), from_desc->buffer, to_desc->buffer, + from->byte_offset, to->byte_offset, nbytes, 0, nullptr, + nullptr)); + } else if (IsOpenCLDevice(from->device) && to->device.device_type == kDLCPU) { + const auto* from_desc = static_cast(from->data); + switch (from_desc->layout) { + case cl::BufferDescriptor::MemoryLayout::kBuffer1D: + OPENCL_CALL(clEnqueueReadBuffer( + this->GetQueue(from->device), from_desc->buffer, CL_FALSE, from->byte_offset, nbytes, + static_cast(to->data) + to->byte_offset, 0, nullptr, nullptr)); + break; + case cl::BufferDescriptor::MemoryLayout::kImage2DActivation: + case cl::BufferDescriptor::MemoryLayout::kImage2DWeight: + auto image_info = GetImageInfo(from_desc, from); + // TODO(csullivan): Support calculating row_pitch correctly in the case of reuse. + // Note that when utilizing texture pools for memory reuse, the allocated image + // size can be larger than the size to be read. + OPENCL_CALL(clEnqueueReadImage( + this->GetQueue(from->device), from_desc->buffer, CL_FALSE, image_info.origin, + image_info.region, image_info.row_pitch, image_info.slice_pitch, + static_cast(to->data) + to->byte_offset, 0, nullptr, nullptr)); + break; + } + OPENCL_CALL(clFinish(this->GetQueue(from->device))); + } else if (from->device.device_type == kDLCPU && IsOpenCLDevice(to->device)) { + auto* to_desc = static_cast(to->data); + switch (to_desc->layout) { + case cl::BufferDescriptor::MemoryLayout::kBuffer1D: + OPENCL_CALL(clEnqueueWriteBuffer( + this->GetQueue(to->device), to_desc->buffer, CL_FALSE, to->byte_offset, nbytes, + static_cast(from->data) + from->byte_offset, 0, nullptr, nullptr)); + break; + case cl::BufferDescriptor::MemoryLayout::kImage2DActivation: + case cl::BufferDescriptor::MemoryLayout::kImage2DWeight: + auto image_info = GetImageInfo(to_desc, to); + OPENCL_CALL(clEnqueueWriteImage( + this->GetQueue(to->device), to_desc->buffer, CL_FALSE, image_info.origin, + image_info.region, image_info.row_pitch, image_info.slice_pitch, + static_cast(from->data) + from->byte_offset, 0, nullptr, nullptr)); + break; + } + OPENCL_CALL(clFinish(this->GetQueue(to->device))); } else { LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL"; } @@ -291,6 +425,39 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } +TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args, TVMRetValue* rv) { + int device_type = args[0]; + int device_id = args[1]; + int width = args[2]; + int height = args[3]; + int dtype_code_hint = args[4]; + int dtype_bits_hint = args[5]; + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); + *rv = ptr->AllocTextureWorkspace(dev, static_cast(width), static_cast(height), + type_hint); +}); + +TVM_REGISTER_GLOBAL("device_api.opencl.FreeTexture").set_body([](TVMArgs args, TVMRetValue* rv) { + int device_type = args[0]; + int device_id = args[1]; + void* data = args[2]; + OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + ptr->FreeTextureWorkspace(dev, data); + *rv = static_cast(0); +}); + TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = OpenCLWorkspace::Global(); *rv = static_cast(ptr); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 6543b1de460c..397f57b36dad 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -29,6 +29,7 @@ #include #include +#include "../source_utils.h" #include "opencl_common.h" namespace tvm { @@ -63,7 +64,8 @@ class OpenCLWrappedFunc { } // setup arguments. for (cl_uint i = 0; i < arg_size_.size(); ++i) { - OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i])); + auto* arg = static_cast(void_args[i]); + OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg->buffer)); } cl_command_queue queue = w_->GetQueue(t->device); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); @@ -188,6 +190,11 @@ void OpenCLModuleNode::Init() { // split into source artifacts for each kernel parsed_kernels_ = SplitKernels(GetSource("cl")); + ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " + << "source from code generation, but no kernel " + << "delimiter was found."; + ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size()) + << "The number of registered kernels does not match number of parsed kernel sources"; // zero initialize cl_program pointers for each device kernel for (auto& kv : parsed_kernels_) { programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); @@ -242,39 +249,6 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre return kernel; } -std::unordered_map OpenCLModuleNode::SplitKernels( - std::string source) const { - std::unordered_map split_kernels; - if (source.size()) { - std::string del{"// Function: "}; - size_t end; - size_t begin = source.find(del); - ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited " - << "source from code generation, but no kernel " - << "delimiter was found."; - for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) { - begin += del.size(); - end = source.find('\n', begin); - std::string func_name = source.substr(begin, end - begin); - begin = ++end; - // std::string::substr returns either start of next kernel - // or std::string::npos, in the latter case substr returns - // all characters until the end of the source string. - end = source.find(del, begin); - std::string func_source = - source.substr(begin, (end == std::string::npos) ? end : end - begin); - split_kernels.insert({func_name, func_source}); - begin = end; - if (end == std::string::npos) { - break; - } - } - } - ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size()) - << "The number of registered kernels does not match number of parsed kernel sources"; - return split_kernels; -} - Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); diff --git a/src/runtime/opencl/texture_pool.cc b/src/runtime/opencl/texture_pool.cc new file mode 100644 index 000000000000..bf52894da35e --- /dev/null +++ b/src/runtime/opencl/texture_pool.cc @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file texture_pool.h + * \brief Texture pool utility. + */ +#include +#include + +#include "../texture.h" + +namespace tvm { +namespace runtime { + +class TexturePool::Pool { + public: + Pool() = default; + void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint) { + Entry e; + e.data = nullptr; + if (free_list_.size() != 0) { + int64_t req_size = height * width; + Entry new_mem; + int64_t min_added_size = std::numeric_limits::max(); + int64_t min_wasted_size = std::numeric_limits::max(); + std::vector::iterator best_mem; + for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { + if (it->type.code != type_hint.code) { + continue; + } + int64_t old_size = it->x * it->y; + new_mem.x = std::max(it->x, width); + new_mem.y = std::max(it->y, height); + int64_t new_size = new_mem.x * new_mem.y; + int64_t added_size = new_size - old_size; + int64_t wasted_size = new_size - req_size; + // Minimize added size first and wasted size thereafter + if ((min_added_size > 0 && added_size < min_added_size) || + (min_added_size == 0 && wasted_size < min_wasted_size)) { + min_added_size = added_size; + min_wasted_size = wasted_size; + best_mem = it; + } + } + + if (min_added_size == 0) { + // use existing block + e = *best_mem; + free_list_.erase(best_mem); + } else if (min_added_size <= req_size) { + // if added size is less or equal to + // what is needed by alloc, then grow entry + device->FreeDataSpace(dev, best_mem->data); + free_list_.erase(best_mem); + new_mem.type = type_hint; + std::vector shape{int64_t(new_mem.y), int64_t(new_mem.x), 4}; + new_mem.data = device->AllocDataSpace(dev, shape.size(), shape.data(), new_mem.type, + Optional("global.texture")); + e = new_mem; + } + } + + if (e.data == nullptr) { + // create new block + std::vector shape{int64_t(height), int64_t(width), 4}; + e.data = device->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, + Optional("global.texture")); + e.x = width; + e.y = height; + e.type = type_hint; + } + + allocated_.push_back(e); + return e.data; + } + + void Free(void* data) { + Entry e; + if (allocated_.back().data == data) { + // quick path, last allocated. + e = allocated_.back(); + allocated_.pop_back(); + } else { + int index = static_cast(allocated_.size()) - 2; + for (; index >= 0 && allocated_[index].data != data; --index) { + } + ICHECK_GE(index, 0) << "Attempt to free texture that has not been allocated"; + e = allocated_[index]; + allocated_.erase(allocated_.begin() + index); + } + free_list_.push_back(e); + } + + // Release all resources immediately + void Release(Device dev, DeviceAPI* device) { + for (auto& e : allocated_) { + device->FreeDataSpace(dev, e.data); + } + for (auto& e : free_list_) { + device->FreeDataSpace(dev, e.data); + } + allocated_.clear(); + free_list_.clear(); + } + + private: + struct Entry { + void* data; + size_t x; + size_t y; + DLDataType type; + }; + std::vector free_list_; + std::vector allocated_; +}; + +TexturePool::TexturePool(DLDeviceType device_type, DeviceAPI* device) + : device_type_(device_type), device_(device) {} + +TexturePool::~TexturePool() { + for (size_t i = 0; i < array_.size(); ++i) { + if (array_[i] != nullptr) { + Device dev; + dev.device_type = device_type_; + dev.device_id = static_cast(i); + array_[i]->Release(dev, device_); + delete array_[i]; + } + } +} + +void* TexturePool::AllocTexture(Device dev, size_t width, size_t height, DLDataType type_hint) { + if (static_cast(dev.device_id) >= array_.size()) { + array_.resize(dev.device_id + 1, nullptr); + } + if (array_[dev.device_id] == nullptr) { + array_[dev.device_id] = new Pool(); + } + return array_[dev.device_id]->Alloc(dev, device_, width, height, type_hint); +} + +void TexturePool::FreeTexture(Device dev, void* ptr) { + ICHECK(static_cast(dev.device_id) < array_.size() && array_[dev.device_id] != nullptr) + << "Attempt to free texture from null texture pool"; + array_[dev.device_id]->Free(ptr); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 7c852da77df6..3776d18fafcc 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -32,6 +32,7 @@ #define TVM_RUNTIME_PACK_ARGS_H_ #include +#include #include #include diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 1a963bee472c..fa09720fc90d 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -38,9 +38,9 @@ namespace runtime { struct Registry::Manager { // map storing the functions. - // We deliberately used raw pointer - // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction and forking. + // We deliberately used raw pointer. + // This is because PackedFunc can contain callbacks into the host language (Python) and the + // resource can become invalid because of indeterministic order of destruction and forking. // The resources will only be recycled during program exit. std::unordered_map fmap; // mutex diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 7db84862604f..7272269680c5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -21,7 +21,7 @@ * \file rpc_module.cc * \brief RPC runtime module. */ -#include +#include #include #include #include diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 4e7fe3196d45..1456fc719113 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,7 +21,6 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ -#include #include #include diff --git a/src/runtime/source_utils.cc b/src/runtime/source_utils.cc new file mode 100644 index 000000000000..e1cf94e52e18 --- /dev/null +++ b/src/runtime/source_utils.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file source_utils.cc + */ +#include "source_utils.h" + +namespace tvm { +namespace runtime { + +std::unordered_map SplitKernels(std::string source, + std::string delimiter) { + std::unordered_map split_kernels; + if (source.size()) { + size_t begin = source.find(delimiter); + size_t end = begin; + while (end != std::string::npos) { + begin += delimiter.size(); + end = source.find('\n', begin); + std::string func_name = source.substr(begin, end - begin); + begin = ++end; + end = source.find(delimiter, begin); + std::string func_source = + source.substr(begin, (end == std::string::npos) ? end : end - begin); + split_kernels.insert({func_name, func_source}); + begin = end; + } + } + return split_kernels; +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/source_utils.h b/src/runtime/source_utils.h new file mode 100644 index 000000000000..5476585b945c --- /dev/null +++ b/src/runtime/source_utils.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file source_utils.h + * \brief Minimum source manipulation utils for runtime. + */ + +#ifndef TVM_RUNTIME_SOURCE_UTILS_H_ +#define TVM_RUNTIME_SOURCE_UTILS_H_ + +#include +#include + +namespace tvm { +namespace runtime { +/*! + * \brief Split the source file on separate kernels by specified delimiter. + * \param source The source code of the kernels. + * \param delimiter The delimiter which is using for splitting kernels. + * \return Mapping from primitive name to kernel source + */ +std::unordered_map SplitKernels(std::string source, + std::string delimiter = "// Function: "); +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_SOURCE_UTILS_H_ diff --git a/src/runtime/texture.h b/src/runtime/texture.h new file mode 100644 index 000000000000..83725c00b8c2 --- /dev/null +++ b/src/runtime/texture.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file texture.h + * \brief Texture utilities + */ +#ifndef TVM_RUNTIME_TEXTURE_H_ +#define TVM_RUNTIME_TEXTURE_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief Structure to represent flattened texture shape */ +template +struct Texture2DShape { + T width; + T height; + T channel; +}; + +/*! + * \param shape_rank Rank N of the Nd-shape + * \param convention Storage scope convention to use for flattening + * \return The axis separator that defines the Nd shape partitioning in 2d + */ +inline size_t DefaultTextureLayoutSeparator(size_t shape_rank, + std::string convention = "global.texture") { + // Texture activation: + // e.g. [N,C,H,W,c] -> Texture2d[N*C*H, W, c] + // Texture weight: + // e.g. [O,I,H,W,c] -> Texture2d[O, I*H*W, c] + size_t separator = 0; + if (convention == "global.texture") { + separator = shape_rank - 2; + } else if (convention == "global.texture-weight") { + separator = 1; + } else { + LOG(FATAL) << "Encountered unknown texture lowering convention: " << convention; + } + return separator; +} + +/*! + * \param shape Nd shape + * \param rank Number of dimensions N of the Nd shape + * \param axis The axis separator that splits the Nd axes into two sets + * \return Width and height of the 2d shape + */ +template +Texture2DShape ApplyTexture2DFlattening(const S& shape, size_t rank, size_t axis) { + ICHECK(axis < rank) + << "Number of axes to flatten into rows must be less than shape rank for 2d flattening"; + Texture2DShape texture{1, 1, shape[rank - 1]}; + for (size_t i = 0; i < rank - 1; i++) { + if (i < axis) { + texture.height *= shape[i]; + } else { + texture.width *= shape[i]; + } + } + return texture; +} + +inline bool IsTextureStorage(std::string scope) { + return scope.find("texture") != std::string::npos; +} + +/*! + * \brief A two dimensional storage pool that recycles temporal workspace + * allocations for dynamically allocated texture. See AllocTexture docstring + * for approach to allocation and reuse. + */ +class TVM_DLL TexturePool { + public: + /*! + * \brief Create pool with specific device type and device. + * \param device_type The device type. + * \param device_api The device API. + */ + TexturePool(DLDeviceType device_type, DeviceAPI* device_api); + /*! \brief destructor */ + ~TexturePool(); + + /*! + * \brief Allocate a two dimensional temporal texture workspace on device + * + * \note Two dimensional texture workspaces will be grown and reused + * according to the following strategy: + * - Choose the workspace which minimizes the amount of memory required to + * grow the workspace to fit the request. + * - If a set of workspaces exist that fit the current request without + * expansion, choose the workspace of that set which most closely + * matches the request size, minimizing wasted space. + * + * \param dev The context of allocation. + * \param width The width of the 2d texture to be allocated. + * \param height The height of the 2d texture to be allocated. + * \param type_hint The type of elements. + */ + void* AllocTexture(Device dev, size_t width, size_t height, DLDataType type_hint); + /*! + * \brief Free temporal texture in backend execution. + * + * \param dev The context of allocation. + * \param ptr The pointer to be freed. + */ + void FreeTexture(Device dev, void* ptr); + + private: + class Pool; + /*! \brief pool of device local array */ + std::vector array_; + /*! \brief device type this pool support */ + DLDeviceType device_type_; + /*! \brief The device API */ + DeviceAPI* device_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_TEXTURE_H_ diff --git a/src/runtime/thread_map.h b/src/runtime/thread_map.h new file mode 100644 index 000000000000..c3fc7e31e9bd --- /dev/null +++ b/src/runtime/thread_map.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_THREAD_MAP_H_ +#define TVM_RUNTIME_THREAD_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief Container to hold one value per thread + * + * Similar to thread_local, but intended for use as a non-static or + * non-block variable, such as class member variables. All member + * functions are thread-safe to call. If only the current thread's + * value is accessed, no additional synchronization is required. If + * another thread's stored values are accessed, external + * synchronization may be required. + * + * Calls that only require access to already-existing values will not + * block each other. Calls that require constructing a new value will + * block any other calls. + * + * \tparam T The object type to be held. For instantiation of + * ThreadMap and for calls to ThreadMap::Get, only a forward + * declaration is required. For calls to ThreadMap::GetOrMake, a + * full class definition is required. + */ +template +class ThreadMap { + public: + ThreadMap() {} + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get() const { return this->Get(std::this_thread::get_id()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get(std::thread::id id) const { + std::shared_lock lock(mutex_); + auto res = values_.find(id); + if (res == values_.end()) { + return nullptr; + } else { + return res->second.get(); + } + } + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get() { return const_cast(const_cast*>(this)->Get()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get(std::thread::id id) { + return const_cast(const_cast*>(this)->Get(id)); + } + + /*! \brief Return the current thread's stored object, making it if + * necessary. + * + * Since this method can modify the stored map, there is no + * non-const version available. + * + * \tparam Params Types of the stored object's constructor arguments + * + * \return A reference to the stored object + */ + template + T& GetOrMake(Params&&... params) { + return GetOrMake(std::this_thread::get_id(), std::forward(params)...); + } + + /*! \brief Return the stored object for a given thread, making it if + * necessary + * + * Since this method can modify the stored map, there is no + * non-const version available. + * + * \tparam Params Types of the stored object's constructor arguments + * + * \param id The thread whose object should be returned. + * + * \param params Arguments to the stored object's constructor. Only + * used if the specified thread does not currently exist in the map. + * + * \return A reference to the stored object + */ + template + T& GetOrMake(std::thread::id id, Params&&... params) { + // Try to get stored value first, which would only require shared + // access. + if (T* output = Get(id)) { + return *output; + } + + // Not in map, need exclusive lock to write + std::unique_lock lock(mutex_); + + // Check again, in case another thread got the unique lock first + // and already constructed the object. + auto res = values_.find(id); + if (res != values_.end()) { + return *res->second; + } + + // No value exists, make one and return it. + std::unique_ptr& new_val = values_[id] = + std::make_unique(std::forward(params)...); + return *new_val; + } + + /*! \brief Clears all values held by the ThreadMap + * + * Calling Clear() invalidates any pointers/references previously + * returned by Get/GetOrMake. + * + */ + void Clear() { + std::unique_lock lock(mutex_); + values_.clear(); + } + + private: + //! \brief Mutex to protect values_ + mutable std::shared_timed_mutex mutex_; + + //! \brief Map containing stored values + std::unordered_map> values_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_THREAD_MAP_H_ diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 99126b159143..a7d65944d581 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -24,6 +24,7 @@ #include "vm.h" +#include #include #include @@ -105,6 +106,10 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun } std::unordered_map metrics; + + ICHECK(exec_->op_attrs.find(packed_index) != exec_->op_attrs.end()) + << packed_index_map_[packed_index] << " not found in op attrs"; + auto& op_attrs = exec_->op_attrs.at(packed_index); for (auto p : op_attrs) { if (std::string(p.first).find("layout") != std::string::npos) { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 17a66e419316..c96364108a2a 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include #include @@ -473,7 +473,7 @@ void VirtualMachine::RunLoop() { case Opcode::InvokeClosure: { auto object = ReadRegister(instr.closure); const auto* closure = object.as(); - + ICHECK(closure); std::vector args; for (auto free_var : closure->free_vars) { args.push_back(free_var); diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc deleted file mode 100644 index b7fe2b1ceb21..000000000000 --- a/src/runtime/vulkan/vulkan.cc +++ /dev/null @@ -1,1340 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "../file_utils.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" -#include "../workspace_pool.h" -#include "vulkan_common.h" -#include "vulkan_module.h" -#include "vulkan_shader.h" -#include "vulkan_stream.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -/*! \brief Maximum number of GPU supported in VulkanModule. */ -static constexpr const int kVulkanMaxNumDevice = 8; - -/*! \brief TVM Vulkan binary pack magic number */ -static constexpr const int kVulkanModuleMagic = 0x02700027; - -struct VulkanBuffer { - VkBuffer buffer{VK_NULL_HANDLE}; - VkDeviceMemory memory{VK_NULL_HANDLE}; -}; - -/*! \brief A struct to represent Vulkan buffers backed by host visible memory */ -struct VulkanHostVisibleBuffer { - // A device where the buffer is allocated - VkDevice device{nullptr}; - // Vulkan buffer and memory - VulkanBuffer* vk_buf{nullptr}; - // The corresponding pointer to the host memory - void* host_addr{nullptr}; - // The size of the buffer in bytes - size_t size{0}; -}; - -using VulkanStagingBuffer = VulkanHostVisibleBuffer; -using VulkanUniformBuffer = VulkanHostVisibleBuffer; - -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) { - if (buf && buf->vk_buf) { - if (buf->host_addr != nullptr) { - vkUnmapMemory(buf->device, buf->vk_buf->memory); - } - if (buf->vk_buf->memory != VK_NULL_HANDLE) { - vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr); - } - if (buf->vk_buf->buffer != VK_NULL_HANDLE) { - vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr); - } - buf->host_addr = nullptr; - delete buf->vk_buf; - } -} - -class VulkanThreadEntry { - public: - VulkanThreadEntry(); - static VulkanThreadEntry* ThreadLocal(); - - ~VulkanThreadEntry() { - // Because the thread entry refers to Device API - // The command buffer always will be destroyed before - // the instance and device get destroyed. - // The destruction need to be manually called - // to ensure the destruction order. - - pool.reset(); - streams_.clear(); - for (const auto& kv : staging_buffers_) { - DeleteHostVisibleBuffer(kv.second.get()); - } - } - - Device device; - std::unique_ptr pool; - VulkanStream* Stream(size_t device_id); - VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); - void AllocateUniformBuffer(int device_id, size_t size); - VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size); - - private: - std::unordered_map> streams_; - std::unordered_map> staging_buffers_; - std::unordered_map> uniform_buffers_; -}; - -struct VulkanPipeline { - VulkanContext* vctx_{nullptr}; - VkShaderModule shader{VK_NULL_HANDLE}; - VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; - VkDescriptorPool descriptor_pool{VK_NULL_HANDLE}; - VkDescriptorSet descriptor_set{VK_NULL_HANDLE}; - VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; - VkPipeline pipeline{VK_NULL_HANDLE}; - VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; - bool use_ubo{false}; -}; - -typedef dmlc::ThreadLocalStore VulkanThreadStore; - -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, - VkMemoryPropertyFlags req_prop) { - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs); - uint32_t type_bits = mem_reqs.memoryTypeBits; - VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop); - for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { - if ((type_bits & 1) == 1 && - (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { - return i; - } - type_bits >>= 1; - } - LOG(FATAL) << "Requested memory type not found"; - return 0; -} - -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, - VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - return info; -} - -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(vctx, nbytes, usage); - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (vctx.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = info.size; - minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - -class VulkanDeviceAPI final : public DeviceAPI { - public: - VulkanDeviceAPI(); - ~VulkanDeviceAPI() { - for (auto& vctx : context_) { - vkDestroyDevice(vctx.device, nullptr); - } - if (instance_) { - vkDestroyInstance(instance_, nullptr); - } - } - void SetDevice(Device dev) final { VulkanThreadEntry::ThreadLocal()->device = dev; } - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; - std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); - void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - if (nbytes == 0) { - // Vulkan seems to have issues if we return nullptr on zero size alloc - nbytes = 1; - } - const auto& vctx = context(dev.device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index); - } - - void FreeDataSpace(Device dev, void* ptr) final { - // Before releasing the vkBuffer, call sync to - // finish all the vulkan commands that reference the buffer. - StreamSync(dev, nullptr); - - const auto& vctx = context(dev.device_id); - auto* pbuf = static_cast(ptr); - vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); - vkFreeMemory(vctx.device, pbuf->memory, nullptr); - delete pbuf; - } - - protected: - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, - Device dev_from, Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) final { - ICHECK(stream == nullptr); - Device dev = dev_from; - if (dev_from.device_type == kDLCPU) { - dev = dev_to; - } - - int from_dev_type = static_cast(dev_from.device_type); - int to_dev_type = static_cast(dev_to.device_type); - if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([=](VulkanStreamState* state) { - // 1: copy - const auto* from_buf = static_cast(from); - auto* to_buf = static_cast(to); - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); - // 2: barrier(transfer-> compute|transfer) - ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Vulkan disallow cross device copy."; - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - barrier_info.dstAccessMask = - (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier( - state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, - &barrier_info, 0, nullptr, 0, nullptr); - }); - - } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { - const auto* from_buf = static_cast(from); - const auto& vctx = context(dev_from.device_id); - auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([&](VulkanStreamState* state) { - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = 0; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1, - ©_info); - }); - VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); - if (!vctx.coherent_staging) { - VkMappedMemoryRange mrange; - mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; - mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; - mrange.offset = 0; - mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange)); - } - memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); - } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& vctx = context(dev_to.device_id); - const auto* to_buf = static_cast(to); - VulkanStagingBuffer* temp = - VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); - memcpy(temp->host_addr, static_cast(from) + from_offset, size); - // host side flush if access is not coherent. - // so writes from CPU is visible to GPU - if (!vctx.coherent_staging) { - VkMappedMemoryRange mrange; - mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; - mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; - mrange.offset = 0; - mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange)); - } - - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_to.device_id) - ->Launch([&](VulkanStreamState* state) { - // 0: barrier(host->transfer) - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = 0; - barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, - nullptr); - // 1: copy - VkBufferCopy copy_info; - copy_info.srcOffset = 0; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1, - ©_info); - }); - // TODO(tulloch): should we instead make the staging buffer a property of the - // Stream? This would allow us to elide synchronizations here. - VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize(); - } else { - LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" - << ", from=" << from_dev_type << ", to=" << to_dev_type; - } - } - - public: - // Current vulkan implementation has one "stream" per CPU thread, - // with all commands writing into a single command buffer that is - // submitted on a call to StreamSync. Therefore, for now, these are - // mostly no-ops. If needed in the future, could have multiple - // command buffers to act as multiple streams. - TVMStreamHandle CreateStream(Device dev) final { return nullptr; } - - void FreeStream(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - return; - } - - // Syncing two streams is a nop, since there is only one stream. - void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final { - ICHECK_EQ(event_src, static_cast(nullptr)); - ICHECK_EQ(event_dst, static_cast(nullptr)); - return; - } - - void StreamSync(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - VulkanThreadEntry::ThreadLocal()->Stream(dev.device_id)->Synchronize(); - } - - void SetStream(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - return; - } - - void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { - return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(dev, size); - } - - void FreeWorkspace(Device dev, void* data) final { - VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(dev, data); - } - - static VulkanDeviceAPI* Global() { - // Most of the TVM Global() functions allocate with "new" and do - // not deallocate, as the OS can clean up any leftover buffers at - // the end. In this case, we need the VulkanDeviceAPI destructor - // to call vkDestroyInstance, to prevent a segfault on exit when - // using some nvidia drivers. - static VulkanDeviceAPI inst; - return &inst; - } - - const VulkanContext& context(size_t device_id) const { - ICHECK_LT(device_id, context_.size()); - return context_[device_id]; - } - - private: - VkInstance instance_{nullptr}; - // The physical devices, have 1 to 1 mapping to devices - std::vector context_; -}; - -void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { - size_t index = static_cast(dev.device_id); - if (kind == kExist) { - *rv = static_cast(index < context_.size()); - return; - } - ICHECK_LT(index, context_.size()) << "Invalid device id " << index; - const auto& vctx = context(index); - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); - - switch (kind) { - case kMaxThreadsPerBlock: { - int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations; - *rv = value; - break; - } - case kMaxSharedMemoryPerBlock: { - int64_t value = phy_prop.limits.maxComputeSharedMemorySize; - *rv = value; - break; - } - case kWarpSize: { - VkPhysicalDeviceSubgroupProperties subgroup_prop; - subgroup_prop.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; - subgroup_prop.pNext = NULL; - - VkPhysicalDeviceProperties2 phy_prop2; - phy_prop2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; - phy_prop2.pNext = &subgroup_prop; - - vkGetPhysicalDeviceProperties2(vctx.phy_device, &phy_prop2); - int64_t subgroup_size = subgroup_prop.subgroupSize; - ICHECK(subgroup_size >= 1); - - *rv = subgroup_size; - break; - } - case kComputeVersion: { - int64_t value = phy_prop.apiVersion; - std::ostringstream os; - os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." - << VK_VERSION_PATCH(value); - *rv = os.str(); - break; - } - case kDeviceName: - *rv = std::string(phy_prop.deviceName); - break; - case kMaxClockRate: - break; - case kMultiProcessorCount: - break; - case kExist: - break; - case kMaxThreadDimensions: { - int64_t dims[3]; - dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0]; - dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1]; - dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2]; - std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; - *rv = ss.str(); - break; - } - case kMaxRegistersPerBlock: - break; - case kGcnArch: - break; - case kApiVersion: - *rv = VK_HEADER_VERSION; - break; - case kDriverVersion: { - int64_t value = phy_prop.driverVersion; - std::ostringstream os; - os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." - << VK_VERSION_PATCH(value); - *rv = os.str(); - break; - } - } -} - -VulkanDeviceAPI::VulkanDeviceAPI() { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "TVM"; - app_info.applicationVersion = 0; - app_info.pEngineName = ""; - app_info.engineVersion = 0; - app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0); - - VkInstanceCreateInfo inst_info; - inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - inst_info.pNext = nullptr; - inst_info.flags = 0; - - const auto layers = []() -> std::vector { - uint32_t inst_layer_prop_count; - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); - std::vector inst_layer_prop(inst_layer_prop_count); - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); - std::vector l; - for (const auto& lp : inst_layer_prop) { - // TODO(tulloch): add CMAKE options. - (void)lp; // suppress unused variable warning. -#ifdef USE_VULKAN_VALIDATION - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_standard_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_parameter_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { - l.push_back("VK_LAYER_KHRONOS_validation"); - } -#endif - } - return l; - }(); - - const auto instance_extensions = []() -> std::vector { - uint32_t inst_extension_prop_count; - VULKAN_CALL( - vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); - std::vector inst_extension_prop(inst_extension_prop_count); - VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, - inst_extension_prop.data())); - std::vector extensions; - for (const auto& ip : inst_extension_prop) { - if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) { - extensions.push_back("VK_KHR_get_physical_device_properties2"); - } - } - return extensions; - }(); - - inst_info.pApplicationInfo = &app_info; - inst_info.enabledLayerCount = layers.size(); - inst_info.ppEnabledLayerNames = layers.data(); - inst_info.enabledExtensionCount = instance_extensions.size(); - inst_info.ppEnabledExtensionNames = instance_extensions.data(); - - VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); - - uint32_t phy_dev_count = 0; - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); - std::vector all_phy_devs(phy_dev_count); - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); - for (VkPhysicalDevice phy_dev : all_phy_devs) { - // Get a list of queue families supporting compute, in order of preference. We currently only - // make use of the most preferred one family. - std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); - if (queue_family_indexes.empty()) continue; - uint32_t queue_family_index = queue_family_indexes[0]; - float priority = 1.0f; - - struct VkDeviceQueueCreateInfo queue_create_info; - queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - queue_create_info.pNext = nullptr; - queue_create_info.flags = 0; - queue_create_info.queueFamilyIndex = queue_family_index; - queue_create_info.queueCount = 1; - queue_create_info.pQueuePriorities = &priority; - - VulkanContext ctx; - // setup context - ctx.phy_device = phy_dev; - vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); - - const auto extensions = [&]() { - uint32_t device_extension_prop_count; - VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, - &device_extension_prop_count, nullptr)); - std::vector device_extension_prop(device_extension_prop_count); - VULKAN_CALL(vkEnumerateDeviceExtensionProperties( - ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); - std::vector extensions; - for (const auto& dp : device_extension_prop) { - if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) { - extensions.push_back("VK_KHR_push_descriptor"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_descriptor_update_template"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_get_memory_requirements2"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_dedicated_allocation"); - } - } - return extensions; - }(); - - // All TVM-generated spirv shaders are marked as requiring int64 - // support, so we need to request it from the device, too. - VkPhysicalDeviceFeatures enabled_features = {}; - enabled_features.shaderInt64 = VK_TRUE; - - VkDeviceCreateInfo device_create_info; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = 1; - device_create_info.pQueueCreateInfos = &queue_create_info; - device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = extensions.size(); - device_create_info.ppEnabledExtensionNames = extensions.data(); - device_create_info.pEnabledFeatures = &enabled_features; - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); - ctx.queue_mutex.reset(new std::mutex()); - vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); - ctx.queue_family_index = queue_family_index; - // Find suitable memory type for staging and compute - // Find suitable compute index. - VkBuffer buffer; - VkMemoryRequirements req_staging, req_compute; - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = 1024; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(ctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - - // get staging requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging); - vkDestroyBuffer(ctx.device, buffer, nullptr); - // get compute requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute); - vkDestroyBuffer(ctx.device, buffer, nullptr); - - // Query phyiscal device property - // find a memory that is host visible, no need to be consistent - int win_rank = -1; - VkPhysicalDeviceMemoryProperties prop; - vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop); - - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - if (rank > win_rank) { - win_rank = rank; - ctx.staging_mtype_index = k; - ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; - - win_rank = -1; - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - // prefer not host visible - rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); - if (rank > win_rank) { - win_rank = rank; - ctx.compute_mtype_index = k; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - auto has_extension = [&extensions](const char* query) { - return std::any_of(extensions.begin(), extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - -#ifdef USE_VULKAN_IMMEDIATE_MODE - if (has_extension("VK_KHR_push_descriptor") && - has_extension("VK_KHR_descriptor_update_template")) { - ctx.descriptor_template_khr_functions = std::unique_ptr( - new VulkanDescriptorTemplateKHRFunctions()); - ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR = - CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkCreateDescriptorUpdateTemplateKHR")); - ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR = - CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkDestroyDescriptorUpdateTemplateKHR")); - ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR = - CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkUpdateDescriptorSetWithTemplateKHR")); - ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR = - CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR")); - } -#endif - -#ifdef USE_VULKAN_DEDICATED_ALLOCATION - if (has_extension("VK_KHR_get_memory_requirements2") && - has_extension("VK_KHR_dedicated_allocation")) { - ctx.get_buffer_memory_requirements_2_functions = - std::unique_ptr( - new VulkanGetBufferMemoryRequirements2Functions()); - ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR = - CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr( - ctx.device, "vkGetBufferMemoryRequirements2KHR")); - } -#endif - context_.push_back(std::move(ctx)); - } - - LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; - for (size_t i = 0; i < context_.size(); ++i) { - LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName - << "\' phy_dev_id=" << context_[i].phy_device - << " use_immediate=" << context_[i].UseImmediate(); - } -} - -std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { - uint32_t queue_prop_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); - std::vector queue_props(queue_prop_count); - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); - - std::vector result; - // Prefer compute-only queues. On cerain devices supporting this (e.g. Mesa RADV), using - // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { - result.push_back(i); - } - } - // Now, push the compute queues that we skipped above into the list. - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { - result.push_back(i); - } - } - return result; -} - -// namespace vulkan -class VulkanModuleNode; - -// a wrapped function class to get packed func. -class VulkanWrappedFunc { - public: - void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { - m_ = m; - sptr_ = sptr; - func_name_ = func_name; - num_buffer_args_ = num_buffer_args; - num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); - } - - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; - - private: - // internal module - VulkanModuleNode* m_; - // the resource holder - ObjectPtr sptr_; - // v The name of the function. - std::string func_name_; - // Number of buffer arguments - size_t num_buffer_args_; - // number of packed arguments. - size_t num_pack_args_; - // Device state cache per device. - // mark as mutable, to enable lazy initialization - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; - - mutable std::array, kVulkanMaxNumDevice> scache_; -}; - -// Multi-device enabled module. -class VulkanModuleNode final : public runtime::ModuleNode { - public: - explicit VulkanModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) {} - - const char* type_key() const final { return "vulkan"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); - const FunctionInfo& info = it->second; - VulkanWrappedFunc f; - size_t num_buffer_args = NumBufferArgs(info.arg_types); - f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); - return PackFuncNonBufferArg(std::move(f), info.arg_types); - } - - ~VulkanModuleNode() { - // cleanup vulkan related caches. - for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) { - for (auto& kv : ecache_[device_id]) { - auto& pe = kv.second; - ICHECK(pe); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - - if (pe->descriptor_update_template != VK_NULL_HANDLE) { - vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( - vctx.device, pe->descriptor_update_template, nullptr); - } - vkDestroyPipeline(vctx.device, pe->pipeline, nullptr); - vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr); - vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); - vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); - vkDestroyShaderModule(vctx.device, pe->shader, nullptr); - } - } - } - - std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, - size_t num_pack_args) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - std::lock_guard lock(mutex_); - const auto& cp = ecache_[device_id][func_name]; - if (cp) { - return cp; - } - // Create new pipeline - auto pe = std::make_shared(); - { - // create shader - auto sit = smap_.find(func_name); - ICHECK(sit != smap_.end()); - pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO); - const std::vector& data = sit->second.data; - VkShaderModuleCreateInfo shader_cinfo; - shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - shader_cinfo.pNext = nullptr; - shader_cinfo.flags = 0; - shader_cinfo.codeSize = data.size() * sizeof(uint32_t); - shader_cinfo.pCode = data.data(); - VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader))); - } - std::vector arg_binding; - std::vector arg_template; - std::vector descriptor_set_pool_sizes; - uint32_t num_pod = 0, num_buffer = 0; - - auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes]( - uint32_t binding, VkDescriptorType desc_type) { - { - auto result = - std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(), - [&](const auto& psize) { return psize.type == desc_type; }); - if (result == descriptor_set_pool_sizes.end()) { - VkDescriptorPoolSize new_size; - new_size.type = desc_type; - new_size.descriptorCount = 1; - descriptor_set_pool_sizes.push_back(new_size); - } else { - result->descriptorCount++; - } - } - - { - VkDescriptorSetLayoutBinding bd; - bd.binding = binding; - bd.descriptorType = desc_type; - bd.descriptorCount = 1; - bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - bd.pImmutableSamplers = nullptr; - arg_binding.push_back(bd); - } - { - VkDescriptorUpdateTemplateEntryKHR tpl; - tpl.dstBinding = binding; - tpl.dstArrayElement = 0; - tpl.descriptorCount = 1; - tpl.descriptorType = desc_type; - tpl.offset = binding * sizeof(VkDescriptorBufferInfo); - tpl.stride = sizeof(VkDescriptorBufferInfo); - arg_template.push_back(tpl); - } - }; - - { - auto fit = fmap_.find(func_name); - ICHECK(fit != fmap_.end()); - for (DLDataType arg_type : fit->second.arg_types) { - if (arg_type.code == kTVMOpaqueHandle) { - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); - ++num_buffer; - } else { - ++num_pod; - } - } - } - - size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); - if (pe->use_ubo) { - // Use UBO instead of push constants - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); - VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars); - } - - { - VkDescriptorSetLayoutCreateInfo descrip_cinfo; - descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; - descrip_cinfo.pNext = nullptr; - descrip_cinfo.flags = 0; - if (vctx.UseImmediate()) { - descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } - descrip_cinfo.bindingCount = arg_binding.size(); - descrip_cinfo.pBindings = arg_binding.data(); - VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr, - &(pe->descriptor_set_layout))); - } - - if (!vctx.UseImmediate()) { - VkDescriptorPoolCreateInfo descrip_pool_cinfo; - descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; - descrip_pool_cinfo.pNext = nullptr; - descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT; - descrip_pool_cinfo.maxSets = 1; - descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); - descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); - VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr, - &(pe->descriptor_pool))); - - VkDescriptorSetAllocateInfo alloc_info; - alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; - alloc_info.pNext = nullptr; - alloc_info.descriptorPool = pe->descriptor_pool; - alloc_info.descriptorSetCount = 1; - alloc_info.pSetLayouts = &(pe->descriptor_set_layout); - VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set))); - } - - VkPushConstantRange crange; - crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - crange.offset = 0; - crange.size = sizeof(ArgUnion64) * num_pack_args; - - VkPipelineLayoutCreateInfo playout_cinfo; - playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; - playout_cinfo.pNext = nullptr; - playout_cinfo.flags = 0; - playout_cinfo.setLayoutCount = 1; - playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); - - if (0 < nbytes_scalars && !pe->use_ubo) { - playout_cinfo.pushConstantRangeCount = 1; - playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); - } else { - playout_cinfo.pushConstantRangeCount = 0; - playout_cinfo.pPushConstantRanges = nullptr; - } - - VULKAN_CALL( - vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); - - VkComputePipelineCreateInfo pipeline_cinfo; - pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; - pipeline_cinfo.pNext = nullptr; - pipeline_cinfo.flags = 0; - pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - pipeline_cinfo.stage.pNext = nullptr; - pipeline_cinfo.stage.flags = 0; - pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; - pipeline_cinfo.stage.module = pe->shader; - pipeline_cinfo.stage.pName = func_name.c_str(); - pipeline_cinfo.stage.pSpecializationInfo = nullptr; - pipeline_cinfo.layout = pe->pipeline_layout; - pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; - pipeline_cinfo.basePipelineIndex = 0; - VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, - &(pe->pipeline))); - - if (vctx.UseImmediate()) { - VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; - descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; - descrip_template_cinfo.pNext = 0; - descrip_template_cinfo.flags = 0; - descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size(); - descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data(); - descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR; - descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout; - descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; - descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; - descrip_template_cinfo.set = 0; - VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( - vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); - } - ecache_[device_id][func_name] = pe; - return pe; - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; - std::string meta_file = GetMetaFilePath(file_name); - SaveMetaDataToFile(meta_file, fmap_); - std::string data_bin; - dmlc::MemoryStringStream fs(&data_bin); - dmlc::Stream* stream = &fs; - uint32_t magic = kVulkanModuleMagic; - stream->Write(magic); - stream->Write(smap_); - SaveBinaryToFile(file_name, data_bin); - } - - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(fmt_); - stream->Write(fmap_); - stream->Write(smap_); - } - std::string GetSource(const std::string& format) final { - // can only return source code. - return source_; - } - - private: - // function information table. - std::unordered_map smap_; - // function information table. - std::unordered_map fmap_; - // The format - std::string fmt_{"vulkan"}; - // The source - std::string source_; - - // Guards accesses to `ecache_` - std::mutex mutex_; - std::array>, kVulkanMaxNumDevice> - ecache_; -}; - -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source) { - auto n = make_object(smap, fmap, source); - return Module(n); -} - -VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } - -VulkanHostVisibleBuffer* GetOrAllocate( - int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, - std::unordered_map>* buffers_ptr, - bool sync_before_realloc = false) { - auto& buffers = *buffers_ptr; - if (!buffers[device_id]) { - buffers[device_id] = std::make_unique(); - } - - auto& buf = *(buffers[device_id]); - if (buf.device != nullptr && buf.size < size) { - // free previous buffer - if (sync_before_realloc) { - // For the deferred execution mode, we need to make sure that old tasks that use - // the older, smaller buffer get finished - // Synchronization on staging buffers is done after host to device memory copy - // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization - // points - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); - } - DeleteHostVisibleBuffer(&buf); - } - - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - - if (buf.device == nullptr) { - buf.device = vctx.device; - } - if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); - buf.size = size; - } - return &buf; -} - -VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_); -} - -void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(vctx, info, prop); - GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, - &uniform_buffers_, true); -} - -VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) { - auto& buf = uniform_buffers_[device_id]; - ICHECK(buf); - ICHECK_GE(buf->size, size); - return buf.get(); -} - -VulkanThreadEntry::VulkanThreadEntry() - : pool(std::make_unique(static_cast(kDLVulkan), - VulkanDeviceAPI::Global())) { - device.device_id = 0; - device.device_type = static_cast(kDLVulkan); -} - -VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { - if (!streams_[device_id]) { - streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id))); - } - return streams_[device_id].get(); -} - -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - const ArgUnion64* pack_args) const { - int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; - ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - if (!scache_[device_id]) { - scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); - } - const auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - std::vector descriptor_buffers; - descriptor_buffers.resize(num_buffer_args_); - for (size_t i = 0; i < num_buffer_args_; ++i) { - void* buf = args[static_cast(i)]; - VkDescriptorBufferInfo binfo; - binfo.buffer = static_cast(buf)->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers[i] = binfo; - } - const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - CHECK(ubo->host_addr) << "The UBO host buffer is not allocated"; - VkDescriptorBufferInfo binfo; - binfo.buffer = ubo->vk_buf->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers.push_back(binfo); - } - if (vctx.UseImmediate()) { - // Can safely capture by reference as this lambda is immediately executed on the calling thread. - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { - vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); - vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( - state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, - descriptor_buffers.data()); - - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args, nbytes_scalars); - } else if (num_pack_args_ > 0) { - vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, - VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), - pack_args); - } - - vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, - 1, &barrier_info, 0, nullptr, 0, nullptr); - }); - return; - } - - // Otherwise, the more expensive deferred path. - std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { - std::vector write_descriptor_sets; - write_descriptor_sets.resize(descriptor_buffers.size()); - for (size_t i = 0; i < write_descriptor_sets.size(); i++) { - write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; - write_descriptor_sets[i].pNext = 0; - write_descriptor_sets[i].dstSet = pipeline->descriptor_set; - write_descriptor_sets[i].dstBinding = i; - write_descriptor_sets[i].dstArrayElement = 0; - write_descriptor_sets[i].descriptorCount = 1; - write_descriptor_sets[i].pImageInfo = 0; - write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); - write_descriptor_sets[i].pTexelBufferView = 0; - - if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) { - // The last binding is for UBO - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; - } else { - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - } - } - vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), - 0, 0); - }; - const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, - device_id](VulkanStreamState* state) { - vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, - nullptr); - - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars); - } else if (num_pack_args_ > 0) { - vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion64), - pack_args_storage.data()); - } - - vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, - 1, &barrier_info, 0, nullptr, 0, nullptr); - }; - VulkanStreamToken deferred_token; - deferred_token.descriptor_set_ = pipeline->descriptor_set; - deferred_token.buffers_.resize(descriptor_buffers.size()); - for (size_t i = 0; i < descriptor_buffers.size(); ++i) { - deferred_token.buffers_[i] = descriptor_buffers[i].buffer; - } - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred( - deferred_initializer, deferred_kernel, deferred_token); -} - -Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { - std::string data; - std::unordered_map smap; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadBinaryFromFile(file_name, &data); - LoadMetaDataFromFile(meta_file, &fmap); - dmlc::MemoryStringStream fs(&data); - dmlc::Stream* stream = &fs; - uint32_t magic; - stream->Read(&magic); - ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; - stream->Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); -} - -Module VulkanModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; - std::unordered_map fmap; - - std::string fmt; - stream->Read(&fmt); - stream->Read(&fmap); - stream->Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); -} - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); - -TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global(); - *rv = static_cast(ptr); -}); - -} // namespace vulkan -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/runtime/vulkan/vulkan_buffer.cc new file mode 100644 index 000000000000..ef8215c01738 --- /dev/null +++ b/src/runtime/vulkan/vulkan_buffer.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_buffer.h" + +#include + +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) { + VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO}; + info.size = nbytes; + // Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to + // specify the queue families. + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + return info; +} + +VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index) + : device_(device) { + // Create a buffer + VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage); + VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer)); + + // Allocate memory + VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; + mem_info.allocationSize = buffer_info.size; + mem_info.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR dedicated_info = { + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR}; + + bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize); + if (use_dedicated_allocation) { + dedicated_info.buffer = buffer; + mem_info.pNext = &dedicated_info; + } + + VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory)); + + // Bind the buffer to the allocated memory + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); +} + +VulkanBuffer::~VulkanBuffer() { + if (buffer) { + vkDestroyBuffer(device_, buffer, nullptr); + } + if (memory) { + vkFreeMemory(device_, memory, nullptr); + } +} + +VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) + : device_(other.device_), buffer(other.buffer), memory(other.memory) { + other.device_ = VK_NULL_HANDLE; + other.buffer = VK_NULL_HANDLE; + other.memory = VK_NULL_HANDLE; +} + +VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) { + std::swap(device_, other.device_); + std::swap(buffer, other.buffer); + std::swap(memory, other.memory); + return *this; +} + +bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes) { + if (device.get_buffer_memory_requirements_2_functions) { + // Which buffer to request information about + VkBufferMemoryRequirementsInfo2KHR req_info2 = { + VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR}; + req_info2.buffer = buffer; + + // What information to request + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + + VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR}; + req2.pNext = &dedicated_req; + + device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + device, &req_info2, &req2); + if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) { + *nbytes = req2.memoryRequirements.size; + return true; + } + } + + return false; +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, + VkBufferUsageFlags usage, uint32_t mem_type_index) + : vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) { + VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr)); +} + +VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() { + if (host_addr) { + vkUnmapMemory(vk_buf.device_, vk_buf.memory); + } +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other) + : vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) { + other.host_addr = nullptr; + other.size = 0; +} + +VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) { + std::swap(vk_buf, other.vk_buf); + std::swap(host_addr, other.host_addr); + std::swap(size, other.size); + + return *this; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h new file mode 100644 index 000000000000..a3e37431e434 --- /dev/null +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ + +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanDevice; + +class VulkanBuffer { + public: + /* \brief Allocate memory on the device + * + * \param device Which device should have the memory allocation. + * The VulkanDevice given should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Destructor, deallocates the memory and buffer. + ~VulkanBuffer(); + + // Forbid copy assignment/constructor + VulkanBuffer(const VulkanBuffer&) = delete; + VulkanBuffer& operator=(const VulkanBuffer&) = delete; + + // Allow move assignment/constructor + VulkanBuffer(VulkanBuffer&&); + VulkanBuffer& operator=(VulkanBuffer&&); + + private: + /*! \brief Whether this buffer should be allocated using dedicated + * allocation + * + * In typical usage, there will be one VkDeviceMemory that has a + * large number of VkBuffers pointing to it. Currently, the TVM + * Vulkan runtime has a single VkBuffer for each VkDeviceMemory. In + * this case, there can be performance benefits by explicitly + * marking this as a dedicated allocation. The function returns + * true if the device supports the dedicated allocation extension, + * and the buffer either requires or has better performance with a + * dedicated allocation. + * + * \param[out] nbytes If using dedicated allocation, the number of + * bytes required for the allocation. If not using dedicated + * allocation, this value is unchanged. + * + * \returns Whether the allocation should use the dedicated + * allocation extension. + */ + static bool UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes); + + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + /*! \brief Pointer to the device that owns this buffer. + * + * Assumes that the VulkanBuffer will be destructed before the + * VulkanDevice, and this will never be a dangling reference. + * Stores a VkDevice and not a VulkanDevice, because the + * VulkanDevice may be moved to a different location while the + * VulkanBuffer is alive. + */ + VkDevice device_{VK_NULL_HANDLE}; + + //! \brief Handle to the logical buffer on the device + VkBuffer buffer{VK_NULL_HANDLE}; + + //! \brief Handle to the physical device memory + VkDeviceMemory memory{VK_NULL_HANDLE}; + + friend class VulkanHostVisibleBuffer; +}; + +/*! \brief A struct to represent Vulkan buffers backed by host visible memory */ +class VulkanHostVisibleBuffer { + public: + /* \brief Allocate memory on the device, visible to the host + * + * \param device Which GPU device should have the memory allocation. + * The VulkanDevice specified should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Unmap memory and deallocate. + ~VulkanHostVisibleBuffer(); + + // Forbid copy assignment/constructor + VulkanHostVisibleBuffer(const VulkanHostVisibleBuffer&) = delete; + VulkanHostVisibleBuffer& operator=(const VulkanHostVisibleBuffer&) = delete; + + // Allow move assignment/constructor + VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&&); + VulkanHostVisibleBuffer& operator=(VulkanHostVisibleBuffer&&); + + private: + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + VulkanBuffer vk_buf; + void* host_addr{nullptr}; + size_t size{0}; +}; + +using VulkanStagingBuffer = VulkanHostVisibleBuffer; +using VulkanUniformBuffer = VulkanHostVisibleBuffer; + +VulkanHostVisibleBuffer* GetOrAllocate( + int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, + std::unordered_map>* buffers_ptr, + bool sync_before_realloc = false); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ diff --git a/src/runtime/vulkan/vulkan_common.cc b/src/runtime/vulkan/vulkan_common.cc new file mode 100644 index 000000000000..30df8b86ecd5 --- /dev/null +++ b/src/runtime/vulkan/vulkan_common.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_common.h" + +#include + +namespace tvm { +namespace runtime { +namespace vulkan { + +std::vector FindEnabledExtensions( + const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions) { + std::set available_extensions; + for (const auto& prop : ext_prop) { + if (prop.specVersion > 0) { + available_extensions.insert(prop.extensionName); + } + } + + std::vector enabled_extensions; + for (const auto& ext : required_extensions) { + ICHECK(available_extensions.count(ext)) + << "Required vulkan extension \"" << ext << "\" not supported by driver"; + enabled_extensions.push_back(ext); + } + + for (const auto& ext : optional_extensions) { + if (available_extensions.count(ext)) { + enabled_extensions.push_back(ext); + } + } + + return enabled_extensions; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 2ef879a487a6..a03801cf511f 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,12 @@ namespace tvm { namespace runtime { namespace vulkan { +/*! \brief Maximum number of GPU supported in VulkanModule. */ +static constexpr const int kVulkanMaxNumDevice = 8; + +/*! \brief TVM Vulkan binary pack magic number */ +static constexpr const int kVulkanModuleMagic = 0x02700027; + const int kMaxPushConstantsBytes = 128; /*! \brief A mask used when we attach additional information to shaders */ @@ -87,10 +94,10 @@ inline const char* VKGetErrorString(VkResult error) { * \brief Protected Vulkan call * \param func Expression to call. */ -#define VULKAN_CHECK_ERROR(__e) \ - { \ - ICHECK(__e == VK_SUCCESS) << "Vulan Error, code=" << __e << ": " \ - << vulkan::VKGetErrorString(__e); \ +#define VULKAN_CHECK_ERROR(__e) \ + { \ + ICHECK(__e == VK_SUCCESS) << "Vulkan Error, code=" << __e << ": " \ + << vulkan::VKGetErrorString(__e); \ } #define VULKAN_CALL(func) \ @@ -99,45 +106,9 @@ inline const char* VKGetErrorString(VkResult error) { VULKAN_CHECK_ERROR(__e); \ } -struct VulkanDescriptorTemplateKHRFunctions { - PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr}; - PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr}; - PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr}; - PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr}; -}; - -struct VulkanGetBufferMemoryRequirements2Functions { - PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; -}; - -struct VulkanContext { - // phyiscal device - VkPhysicalDevice phy_device{nullptr}; - // Phyiscal device property - VkPhysicalDeviceProperties phy_device_prop; - // Memory type index for staging. - uint32_t staging_mtype_index{0}; - // whether staging is coherent - bool coherent_staging{false}; - - std::unique_ptr descriptor_template_khr_functions{nullptr}; - std::unique_ptr - get_buffer_memory_requirements_2_functions{nullptr}; - // Memory type index for compute - uint32_t compute_mtype_index{0}; - // The logical device - VkDevice device{nullptr}; - // command queue - - std::unique_ptr queue_mutex; - VkQueue queue{nullptr}; - // queue family_index; - uint32_t queue_family_index{0}; - // Queue family index. - VkQueueFamilyProperties queue_prop; - - bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } -}; +std::vector FindEnabledExtensions(const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc new file mode 100644 index 000000000000..5e4be8209550 --- /dev/null +++ b/src/runtime/vulkan/vulkan_device.cc @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_device.h" + +#include +#include +#include +#include + +#include "vulkan_common.h" +#include "vulkan_device.h" +#include "vulkan_device_api.h" +#include "vulkan_instance.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, + const VulkanDevice& device) { + /////////////////////////////////////////////////////////////// + // Query properties from Vulkan API // + /////////////////////////////////////////////////////////////// + + // Declare output locations for properties + VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; + VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; + VkPhysicalDeviceSubgroupProperties subgroup = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; + + // Need to do initial query in order to check the apiVersion. + vkGetPhysicalDeviceProperties(device, &properties.properties); + + // Set up linked list for property query + { + void** pp_next = &properties.pNext; + if (device.HasExtension("VK_KHR_driver_properties")) { + *pp_next = &driver; + pp_next = &driver.pNext; + } + if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { + *pp_next = &subgroup; + pp_next = &subgroup.pNext; + } + } + + // Declare output locations for features + VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + // Set up linked list for feature query + { + void** pp_next = &features.pNext; + if (device.HasExtension("VK_KHR_8bit_storage")) { + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (device.HasExtension("VK_KHR_16bit_storage")) { + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + if (device.HasExtension("VK_KHR_shader_float16_int8")) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + } + + if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { + // Preferred method, call to get all properties that can be queried. + auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); + vkGetPhysicalDeviceProperties2KHR(device, &properties); + + auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); + vkGetPhysicalDeviceFeatures2KHR(device, &features); + } else { + // Fallback, get as many features as we can from the Vulkan1.0 + // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. + vkGetPhysicalDeviceFeatures(device, &features.features); + } + + /////////////////////////////////////////////////////////////// + // Fill member variables from Vulkan structures // + /////////////////////////////////////////////////////////////// + + supports_float16 = float16_int8.shaderFloat16; + supports_float32 = true; + supports_float64 = features.features.shaderFloat64; + supports_int8 = float16_int8.shaderInt8; + supports_int16 = features.features.shaderInt16; + supports_int32 = true; + supports_int64 = features.features.shaderInt64; + supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess; + supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess; + supports_storage_buffer_storage_class = + device.HasExtension("VK_KHR_storage_buffer_storage_class"); + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + supports_push_descriptor = device.HasExtension("VK_KHR_push_descriptor") && + device.HasExtension("VK_KHR_descriptor_update_template"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); + if (disable && *disable) { + supports_push_descriptor = false; + } + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + supports_dedicated_allocation = device.HasExtension("VK_KHR_get_memory_requirements2") && + device.HasExtension("VK_KHR_dedicated_allocation"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); + if (disable && *disable) { + supports_dedicated_allocation = false; + } + } + + // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically + // needed, since it will be set so long at least one queue has + // VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future + // confusion.. + supported_subgroup_operations = + (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; + + max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations; + + // Even if we can't query it, warp size must be at least 1. + thread_warp_size = std::max(subgroup.subgroupSize, 1U); + + max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0]; + max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1]; + max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2]; + max_push_constants_size = properties.properties.limits.maxPushConstantsSize; + max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange; + max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange; + max_per_stage_descriptor_storage_buffer = + properties.properties.limits.maxPerStageDescriptorStorageBuffers; + max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize; + device_name = properties.properties.deviceName; + driver_version = properties.properties.driverVersion; + + // By default, use the maximum API version that the driver allows, + // so that any supported features can be used by TVM shaders. + // However, if we can query the conformance version, then limit to + // only using the api version that passes the vulkan conformance + // tests. + vulkan_api_version = properties.properties.apiVersion; + if (device.HasExtension("VK_KHR_driver_properties")) { + auto api_major = VK_VERSION_MAJOR(vulkan_api_version); + auto api_minor = VK_VERSION_MINOR(vulkan_api_version); + if ((api_major > driver.conformanceVersion.major) || + ((api_major == driver.conformanceVersion.major) && + (api_minor > driver.conformanceVersion.minor))) { + vulkan_api_version = + VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); + } + } + + // From "Versions and Formats" section of Vulkan spec. + max_spirv_version = 0x10000; + if (vulkan_api_version >= VK_API_VERSION_1_2) { + max_spirv_version = 0x10500; + } else if (device.HasExtension("VK_KHR_spirv_1_4")) { + max_spirv_version = 0x10400; + } else if (vulkan_api_version >= VK_API_VERSION_1_1) { + max_spirv_version = 0x10300; + } +} + +VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { + vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); + vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); + vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); + vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); +} + +VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions( + VkDevice device) { + vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); +} + +VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device) + : physical_device_(phy_device) { + queue_family_index = SelectComputeQueueFamily(); + if (queue_family_index == uint32_t(-1)) { + // The GPU doesn't support compute, cannot use + return; + } + + enabled_extensions = SelectEnabledExtensions(); + device_properties = VulkanDeviceProperties(instance, *this); + CreateVkDevice(instance); + + // Currently, any exceptions called after this point will prevent + // vkDestroyDevice from being called in the destructor. If this + // becomes an issue, can split out the VulkanDevice into two + // classes, one of which strictly holds the VkDevice, and one which + // holds the ancillary handles that TVM needs. + + vkGetDeviceQueue(device_, queue_family_index, 0, &queue); + + // Find suitable memory type for staging and compute + // Find suitable compute index. + VkBuffer buffer; + VkMemoryRequirements req_staging, req_compute; + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = 1024; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &queue_family_index; + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + + // get staging requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_staging); + vkDestroyBuffer(device_, buffer, nullptr); + // get compute requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_compute); + vkDestroyBuffer(device_, buffer, nullptr); + + // Query phyiscal device property + // find a memory that is host visible, no need to be consistent + int win_rank = -1; + VkPhysicalDeviceMemoryProperties prop; + vkGetPhysicalDeviceMemoryProperties(physical_device_, &prop); + + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; + // match copy requirment + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + if (rank > win_rank) { + win_rank = rank; + staging_mtype_index = k; + coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; + + win_rank = -1; + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; + // match copy requirment + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + // prefer not host visible + rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); + if (rank > win_rank) { + win_rank = rank; + compute_mtype_index = k; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; + + if (device_properties.supports_push_descriptor) { + descriptor_template_khr_functions = + std::make_unique(device_); + } + + if (device_properties.supports_dedicated_allocation) { + get_buffer_memory_requirements_2_functions = + std::make_unique(device_); + } +} + +VulkanDevice::~VulkanDevice() { + // Need to clear anything that uses this device calling + // vkDestroyDevice. Might be a sign that the VkDevice should be + // held by member variable rather than beind owned directly by + // VulkanDevice. + stream_per_thread.Clear(); + staging_buffer_per_thread.Clear(); + uniform_buffer_per_thread.Clear(); + + if (device_) { + vkDestroyDevice(device_, nullptr); + } +} + +VulkanDevice::VulkanDevice(VulkanDevice&& other) { do_swap(std::move(other)); } + +VulkanDevice& VulkanDevice::operator=(VulkanDevice&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanDevice::do_swap(VulkanDevice&& other) { + if (this == &other) { + return; + } + + std::lock(queue_mutex, other.queue_mutex); + std::lock_guard lock_self(queue_mutex, std::adopt_lock); + std::lock_guard lock_other(other.queue_mutex, std::adopt_lock); + + std::swap(device_properties, other.device_properties); + std::swap(staging_mtype_index, other.staging_mtype_index); + std::swap(coherent_staging, other.coherent_staging); + std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions); + std::swap(get_buffer_memory_requirements_2_functions, + other.get_buffer_memory_requirements_2_functions); + std::swap(compute_mtype_index, other.compute_mtype_index); + std::swap(queue, other.queue); + std::swap(queue_family_index, other.queue_family_index); + std::swap(physical_device_, other.physical_device_); + std::swap(enabled_extensions, other.enabled_extensions); + std::swap(device_, other.device_); +} + +bool VulkanDevice::SupportsCompute() const { return queue_family_index != uint32_t(-1); } + +void VulkanDevice::QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const { + // Multiple streams (on different threads) use the same VulkanDevice + // instance, so we need to externally synchronize accesses. + std::lock_guard lock(queue_mutex); + VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, fence)); +} + +uint32_t VulkanDevice::SelectComputeQueueFamily() const { + // Get a queue family that supports compute. We currently only use + // one queue from one family. + uint32_t queue_prop_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, nullptr); + std::vector queue_props(queue_prop_count); + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, + dmlc::BeginPtr(queue_props)); + + std::vector result; + // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using + // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { + return i; + } + } + // Now, push the compute queues that we skipped above into the list. + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { + return i; + } + } + + // No queues support compute capability, this GPU cannot be used. + return -1; +} + +std::vector VulkanDevice::SelectEnabledExtensions() const { + std::vector required_extensions{}; + std::vector optional_extensions{ + "VK_KHR_driver_properties", + "VK_KHR_storage_buffer_storage_class", + "VK_KHR_8bit_storage", + "VK_KHR_16bit_storage", + "VK_KHR_shader_float16_int8", + "VK_KHR_push_descriptor", + "VK_KHR_descriptor_update_template", + "VK_KHR_get_memory_requirements2", + "VK_KHR_dedicated_allocation", + "VK_KHR_spirv_1_4", + }; + + uint32_t device_extension_prop_count; + VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, + &device_extension_prop_count, nullptr)); + std::vector device_extension_prop(device_extension_prop_count); + VULKAN_CALL(vkEnumerateDeviceExtensionProperties( + physical_device_, nullptr, &device_extension_prop_count, device_extension_prop.data())); + + return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); +} + +bool VulkanDevice::HasExtension(const char* query) const { + return std::any_of(enabled_extensions.begin(), enabled_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) { + // Enable all features we may use that a device supports. + VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + void** pp_next = &enabled_features.pNext; + bool needs_float16_int8 = false; + + if (device_properties.supports_float16) { + float16_int8.shaderFloat16 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_float64) { + enabled_features.features.shaderFloat64 = true; + } + if (device_properties.supports_int8) { + float16_int8.shaderInt8 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_int16) { + enabled_features.features.shaderInt16 = true; + } + if (device_properties.supports_int64) { + enabled_features.features.shaderInt64 = true; + } + if (device_properties.supports_8bit_buffer) { + storage_8bit.storageBuffer8BitAccess = true; + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (device_properties.supports_16bit_buffer) { + storage_16bit.storageBuffer16BitAccess = true; + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + + if (needs_float16_int8) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + + float priority = 1.0f; + + struct VkDeviceQueueCreateInfo queue_create_info; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.pNext = nullptr; + queue_create_info.flags = 0; + queue_create_info.queueFamilyIndex = queue_family_index; + queue_create_info.queueCount = 1; + queue_create_info.pQueuePriorities = &priority; + + VkDeviceCreateInfo device_create_info; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + device_create_info.flags = 0; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.enabledLayerCount = 0; + device_create_info.ppEnabledLayerNames = nullptr; + device_create_info.enabledExtensionCount = enabled_extensions.size(); + device_create_info.ppEnabledExtensionNames = enabled_extensions.data(); + + if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { + device_create_info.pEnabledFeatures = nullptr; + device_create_info.pNext = &enabled_features; + } else { + device_create_info.pNext = nullptr; + device_create_info.pEnabledFeatures = &enabled_features.features; + } + VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_)); +} + +VulkanStream& VulkanDevice::ThreadLocalStream() { + return const_cast(const_cast(this)->ThreadLocalStream()); +} + +const VulkanStream& VulkanDevice::ThreadLocalStream() const { + return stream_per_thread.GetOrMake(this); +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalStagingBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VulkanStagingBuffer& result = + staging_buffer_per_thread.GetOrMake(*this, min_size, usage, staging_mtype_index); + + if (result.size < min_size) { + result = VulkanStagingBuffer(*this, min_size, usage, staging_mtype_index); + } + + return result; +} + +void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; + auto buffer_info = MakeBufferCreateInfo(min_size, usage); + auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + auto mem_type_index = FindMemoryType(*this, buffer_info, prop); + + VulkanUniformBuffer& result = + uniform_buffer_per_thread.GetOrMake(*this, min_size, usage, mem_type_index); + + if (result.size < min_size) { + result = VulkanUniformBuffer(*this, min_size, usage, mem_type_index); + } +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) { + VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get(); + ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated."; + ICHECK_GE(buffer->size, min_size) + << "Vulkan uniform buffer of size " << min_size << " requested, but only " << buffer->size + << " was previously allocated."; + return *buffer; +} + +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, + VkMemoryPropertyFlags req_prop) { + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); + + VkMemoryRequirements mem_reqs; + vkGetBufferMemoryRequirements(device, buffer, &mem_reqs); + uint32_t type_bits = mem_reqs.memoryTypeBits; + VkPhysicalDeviceMemoryProperties phy_mem_prop; + vkGetPhysicalDeviceMemoryProperties(device, &phy_mem_prop); + for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { + if ((type_bits & 1) == 1 && + (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { + return i; + } + type_bits >>= 1; + } + LOG(FATAL) << "Requested memory type not found"; + return 0; +} + +VulkanHostVisibleBuffer* GetOrAllocate( + int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, + std::unordered_map>* buffers_ptr, + bool sync_before_realloc) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + + auto& buffers = *buffers_ptr; + + bool needs_alloc = !buffers[device_id] || (buffers[device_id]->size < size); + bool is_realloc = buffers[device_id] && (buffers[device_id]->size < size); + if (is_realloc && sync_before_realloc) { + device.ThreadLocalStream().Synchronize(); + } + + if (needs_alloc) { + auto new_buffer = + std::make_unique(device, size, usage, mem_type_index); + buffers[device_id] = std::move(new_buffer); + } + return buffers[device_id].get(); +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h new file mode 100644 index 000000000000..045628bc9092 --- /dev/null +++ b/src/runtime/vulkan/vulkan_device.h @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../thread_map.h" +#include "vulkan/vulkan_core.h" +#include "vulkan_buffer.h" +#include "vulkan_stream.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanInstance; +class VulkanDevice; + +struct VulkanDescriptorTemplateKHRFunctions { + explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device); + + PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr}; + PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr}; + PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr}; + PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr}; +}; + +struct VulkanGetBufferMemoryRequirements2Functions { + explicit VulkanGetBufferMemoryRequirements2Functions(VkDevice device); + + PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; +}; + +/*! + * \brief Stores the capabilities/limits queried from the physical device. + * + * The member variables here have a 1-1 mapping to Target parameters, + * if target->kind->device_type==kDLVulkan. A separate struct is used + * to maintain the boundary between the Vulkan runtime in + * libtvm_runtime.so, and the Target object in libtvm.so. + */ +struct VulkanDeviceProperties { + VulkanDeviceProperties() {} + VulkanDeviceProperties(const VulkanInstance& instance, const VulkanDevice& device); + + bool supports_float16{false}; + bool supports_float32{true}; + bool supports_float64{false}; + bool supports_int8{false}; + bool supports_int16{false}; + bool supports_int32{true}; + bool supports_int64{false}; + bool supports_8bit_buffer{false}; + bool supports_16bit_buffer{false}; + bool supports_storage_buffer_storage_class{false}; + bool supports_push_descriptor{false}; + bool supports_dedicated_allocation{false}; + uint32_t supported_subgroup_operations{0}; + uint32_t max_num_threads{1}; + uint32_t thread_warp_size{1}; + uint32_t max_block_size_x{1}; + uint32_t max_block_size_y{1}; + uint32_t max_block_size_z{1}; + uint32_t max_push_constants_size{128}; + uint32_t max_uniform_buffer_range{16384}; + uint32_t max_storage_buffer_range{1 << 27}; + uint32_t max_per_stage_descriptor_storage_buffer{4}; + uint32_t max_shared_memory_per_block{16384}; + std::string device_name{"unknown device name"}; + uint32_t driver_version{0}; + uint32_t vulkan_api_version{VK_API_VERSION_1_0}; + uint32_t max_spirv_version{0x10000}; +}; + +/*! \brief Handle to the Vulkan API's VkDevice + * + * Handles all setup and teardown of the class. The owner of the + * VulkanDevice object is responsible for ensuring that it remains + * alive as long as any object that accesses that device is used. + */ +class VulkanDevice { + public: + VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_dev); + ~VulkanDevice(); + + // Allow move constructor/assignment + VulkanDevice(VulkanDevice&&); + VulkanDevice& operator=(VulkanDevice&&); + + // Disable copy constructor/assignment + VulkanDevice(const VulkanDevice&) = delete; + VulkanDevice& operator=(const VulkanDevice&) = delete; + + /*! \brief Expose the internal VkDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkDevice handler itself. + */ + operator VkDevice() const { return device_; } + + /*! \brief Expose the internal VkPhysicalDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkPhysicalDevice handler itself. + */ + operator VkPhysicalDevice() const { return physical_device_; } + + /*! \brief Returns whether this device supports Vulkan compute operations. + * + * If the device does not support Vulkan compute operations, it + * should not be used any further. + */ + bool SupportsCompute() const; + + /*! \brief Calls vkQueueSubmit to run work on the GPU + * + * Currently only supports submitting a single VkSubmitInfo at a + * time. Handles mutexing internally, safe to call from multiple + * CPU threads. + * + * \param submit_info The job submission information to be passed to + * vkQueueSubmit. + * + * \param fence Optional fence to be passed to vkQueueSubmit, + * signals once the command buffers submitted have completed. + */ + void QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const; + + /*! \brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; + + //! \brief Return the VulkanStream for the current CPU thread + VulkanStream& ThreadLocalStream(); + + //! \brief Return the VulkanStream for the current CPU thread + const VulkanStream& ThreadLocalStream() const; + + /*! \brief Return the staging buffer for the current CPU thread + * + * This function may re-allocate the staging buffer depending on the + * size of the previously allocated buffer. + * + * \param min_size The size in bytes of the staging buffer to be + * returned. The buffer may be larger than requested, depending on + * previous use. + */ + VulkanStagingBuffer& ThreadLocalStagingBuffer(size_t min_size); + + /*! \brief Allocate the uniform buffer for the current CPU thread + * + * \param min_size The minimum size in bytes of the uniformn buffer + * to be allocated. If a larger uniform buffer has already been + * allocated, no allocation is performed. + */ + void AllocateThreadLocalUniformBuffer(size_t min_size); + + /*! \brief Return the uniform buffer for the current CPU thread + * + * Assumes that AllocateThreadLocalUniformBuffer has previously been + * called, with a min_size greater than or equal to the min_size of + * the current call. If this is not the case, will throw an + * exception. + * + * \param min_size The minimum size in bytes of the uniform buffer to be + * returned. + */ + VulkanUniformBuffer& ThreadLocalUniformBuffer(size_t min_size); + + // Cached device properties, queried through Vulkan API. + VulkanDeviceProperties device_properties{}; + + // Memory type index for staging. + uint32_t staging_mtype_index{0}; + // whether staging is coherent + bool coherent_staging{false}; + + std::unique_ptr descriptor_template_khr_functions{nullptr}; + std::unique_ptr + get_buffer_memory_requirements_2_functions{nullptr}; + // Memory type index for compute + uint32_t compute_mtype_index{0}; + + // queue family_index; + uint32_t queue_family_index{uint32_t(-1)}; + + bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanDevice&& other); + + /*! \brief Returns a queue family capable of running Vulkan compute + * operations + */ + uint32_t SelectComputeQueueFamily() const; + + /*! \brief Returns the extensions to be enabled. + * + * All char* in the returned vector point to static memory + * allocations, and do not require cleanup. + */ + std::vector SelectEnabledExtensions() const; + + /*! \brief Initialize the VkDevice + * + * Called during VulkanDevice construction. Assumes that + * queue_family_index, device_properties, and enabled_extensions + * have been set. + */ + void CreateVkDevice(const VulkanInstance& instance); + + //! \brief Handle to the Vulkan API physical device + VkPhysicalDevice physical_device_{nullptr}; + + /*! \brief Extensions enabled for this device + * + * Based on supported extensions queried from physical_device_ prior + * to creating device_. Contains only statically allocated string + * literals, no cleanup required. + */ + std::vector enabled_extensions; + + //! \brief Handle to the Vulkan API logical device + VkDevice device_{nullptr}; + + //! \brief Mutex to protect access to queue + mutable std::mutex queue_mutex; + + /*! \brief Handle to Vulkan API VkQueue. + * + * Work can be executed by submitted to this queue using + * VulkanDevice::QueueSubmit. + */ + VkQueue queue{nullptr}; + + /*! \brief The VulkanStream for each CPU thread. + * + * To mimic the semantics of cudaSetDevice and cuLaunchKernel, each + * CPU thread must have a separate stream of execution. The + * ThreadMap is declared mutable so that the streams can be lazily + * generated. + */ + mutable ThreadMap stream_per_thread; + + //! \brief The VulkanStagingBuffer for each CPU thread. + ThreadMap staging_buffer_per_thread; + + //! \brief The VulkanUniformBuffer for each CPU thread. + ThreadMap uniform_buffer_per_thread; +}; + +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, + VkMemoryPropertyFlags req_prop); + +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc new file mode 100644 index 000000000000..1fede98f7211 --- /dev/null +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_device_api.h" + +#include +#include +#include +#include +#include + +#include "vulkan_common.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanDeviceAPI* VulkanDeviceAPI::Global() { + // Most of the TVM Global() functions allocate with "new" and do + // not deallocate, as the OS can clean up any leftover buffers at + // the end. In this case, we need the VulkanDeviceAPI destructor + // to call vkDestroyInstance, to prevent a segfault on exit when + // using some nvidia drivers. + static VulkanDeviceAPI inst; + return &inst; +} + +VulkanDeviceAPI::VulkanDeviceAPI() { + std::vector vulkan_physical_devices = instance_.GetPhysicalDevices(); + for (VkPhysicalDevice phy_dev : vulkan_physical_devices) { + VulkanDevice device(instance_, phy_dev); + + if (device.SupportsCompute()) { + devices_.push_back(std::move(device)); + } + } +} + +VulkanDeviceAPI::~VulkanDeviceAPI() {} + +void VulkanDeviceAPI::SetDevice(Device dev) { + ICHECK_EQ(dev.device_type, kDLVulkan) + << "Active vulkan device cannot be set to non-vulkan device" << dev; + + ICHECK_LE(dev.device_id, static_cast(devices_.size())) + << "Attempted to set active vulkan device to device_id==" << dev.device_id << ", but only " + << devices_.size() << " devices present"; + + active_device_id_per_thread.GetOrMake(0) = dev.device_id; +} + +int VulkanDeviceAPI::GetActiveDeviceID() { return active_device_id_per_thread.GetOrMake(0); } + +VulkanDevice& VulkanDeviceAPI::GetActiveDevice() { return device(GetActiveDeviceID()); } + +void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { + size_t index = static_cast(dev.device_id); + if (kind == kExist) { + *rv = static_cast(index < devices_.size()); + return; + } + + const auto& prop = device(index).device_properties; + + switch (kind) { + case kMaxThreadsPerBlock: { + *rv = int64_t(prop.max_num_threads); + break; + } + case kMaxSharedMemoryPerBlock: { + *rv = int64_t(prop.max_shared_memory_per_block); + break; + } + case kWarpSize: { + *rv = int64_t(prop.thread_warp_size); + break; + } + case kComputeVersion: { + int64_t value = prop.vulkan_api_version; + std::ostringstream os; + os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." + << VK_VERSION_PATCH(value); + *rv = os.str(); + break; + } + case kDeviceName: + *rv = prop.device_name; + break; + + case kMaxClockRate: + break; + + case kMultiProcessorCount: + break; + + case kExist: + break; + + case kMaxThreadDimensions: { + std::stringstream ss; // use json string to return multiple int values; + ss << "[" << prop.max_block_size_x << ", " << prop.max_block_size_y << ", " + << prop.max_block_size_z << "]"; + *rv = ss.str(); + break; + } + + case kMaxRegistersPerBlock: + break; + + case kGcnArch: + break; + + case kApiVersion: + *rv = VK_HEADER_VERSION; + break; + + case kDriverVersion: { + int64_t value = prop.driver_version; + std::ostringstream os; + os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." + << VK_VERSION_PATCH(value); + *rv = os.str(); + break; + } + } +} + +void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) { + size_t index = static_cast(dev.device_id); + const auto& prop = device(index).device_properties; + + if (property == "supports_float16") { + *rv = prop.supports_float16; + } + if (property == "supports_float32") { + *rv = prop.supports_float32; + } + if (property == "supports_float64") { + *rv = prop.supports_float64; + } + if (property == "supports_int8") { + *rv = prop.supports_int8; + } + if (property == "supports_int16") { + *rv = prop.supports_int16; + } + if (property == "supports_int32") { + *rv = prop.supports_int32; + } + if (property == "supports_int64") { + *rv = prop.supports_int64; + } + if (property == "supports_8bit_buffer") { + *rv = prop.supports_8bit_buffer; + } + if (property == "supports_16bit_buffer") { + *rv = prop.supports_16bit_buffer; + } + if (property == "supports_storage_buffer_storage_class") { + *rv = prop.supports_storage_buffer_storage_class; + } + if (property == "supports_push_descriptor") { + *rv = prop.supports_push_descriptor; + } + if (property == "supports_dedicated_allocation") { + *rv = prop.supports_dedicated_allocation; + } + if (property == "supported_subgroup_operations") { + *rv = int64_t(prop.supported_subgroup_operations); + } + if (property == "max_num_threads") { + *rv = int64_t(prop.max_num_threads); + } + if (property == "thread_warp_size") { + *rv = int64_t(prop.thread_warp_size); + } + if (property == "max_block_size_x") { + *rv = int64_t(prop.max_block_size_x); + } + if (property == "max_block_size_y") { + *rv = int64_t(prop.max_block_size_y); + } + if (property == "max_block_size_z") { + *rv = int64_t(prop.max_block_size_z); + } + if (property == "max_push_constants_size") { + *rv = int64_t(prop.max_push_constants_size); + } + if (property == "max_uniform_buffer_range") { + *rv = int64_t(prop.max_uniform_buffer_range); + } + if (property == "max_storage_buffer_range") { + *rv = int64_t(prop.max_storage_buffer_range); + } + if (property == "max_per_stage_descriptor_storage_buffer") { + *rv = int64_t(prop.max_per_stage_descriptor_storage_buffer); + } + if (property == "max_shared_memory_per_block") { + *rv = int64_t(prop.max_shared_memory_per_block); + } + if (property == ":string device_name") { + *rv = prop.device_name; + } + if (property == "driver_version") { + *rv = int64_t(prop.driver_version); + } + if (property == "vulkan_api_version") { + *rv = int64_t(prop.vulkan_api_version); + } + if (property == "max_spirv_version") { + *rv = int64_t(prop.max_spirv_version); + } +} + +void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, + DLDataType type_hint) { + if (nbytes == 0) { + // Vulkan seems to have issues if we return nullptr on zero size alloc + nbytes = 1; + } + const auto& device = this->device(dev.device_id); + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + return new VulkanBuffer(device, nbytes, usage, device.compute_mtype_index); +} + +void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { + // Before releasing the vkBuffer, call sync to + // finish all the vulkan commands that reference the buffer. + StreamSync(dev, nullptr); + + auto* pbuf = static_cast(ptr); + delete pbuf; +} + +void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + auto& pool = pool_per_thread.GetOrMake(kDLVulkan, this); + return pool.AllocWorkspace(dev, size); +} + +void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { + auto* pool = pool_per_thread.Get(); + ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " + << "that has never allocated a workspace"; + pool->FreeWorkspace(dev, data); +} + +TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } + +void VulkanDeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); +} + +// Syncing two streams is a nop, since there is only one stream. +void VulkanDeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, + TVMStreamHandle event_dst) { + ICHECK_EQ(event_src, static_cast(nullptr)); + ICHECK_EQ(event_dst, static_cast(nullptr)); +} + +void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); + device(dev.device_id).ThreadLocalStream().Synchronize(); +} + +void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); +} + +void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, Device dev_from, Device dev_to, + DLDataType type_hint, TVMStreamHandle stream) { + ICHECK(stream == nullptr); + Device dev = dev_from; + if (dev_from.device_type == kDLCPU) { + dev = dev_to; + } + + int from_dev_type = static_cast(dev_from.device_type); + int to_dev_type = static_cast(dev_to.device_type); + if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { + ICHECK_EQ(dev_from.device_id, dev_to.device_id) + << "The Vulkan runtime does not support deviceA to deviceB copies. " + << "This should be changed to a deviceA to CPU copy, followed by a CPU to deviceB copy"; + + device(dev_from.device_id).ThreadLocalStream().Launch([=](VulkanStreamState* state) { + // 1: copy + const auto* from_buf = static_cast(from); + auto* to_buf = static_cast(to); + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); + // 2: barrier(transfer-> compute|transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }); + + } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { + const auto* from_buf = static_cast(from); + auto& device = this->device(dev_from.device_id); + auto& stream = device.ThreadLocalStream(); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + stream.Launch([&](VulkanStreamState* state) { + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = 0; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, staging_buffer.vk_buf.buffer, 1, + ©_info); + }); + stream.Synchronize(); + if (!device.coherent_staging) { + VkMappedMemoryRange mrange; + mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + mrange.pNext = nullptr; + mrange.memory = staging_buffer.vk_buf.memory; + mrange.offset = 0; + mrange.size = VK_WHOLE_SIZE; // size; + VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); + } + memcpy(static_cast(to) + to_offset, static_cast(staging_buffer.host_addr), size); + } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { + auto& device = this->device(dev_to.device_id); + auto& stream = device.ThreadLocalStream(); + const auto* to_buf = static_cast(to); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + memcpy(staging_buffer.host_addr, static_cast(from) + from_offset, size); + // host side flush if access is not coherent. + // so writes from CPU is visible to GPU + if (!device.coherent_staging) { + VkMappedMemoryRange mrange; + mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + mrange.pNext = nullptr; + mrange.memory = staging_buffer.vk_buf.memory; + mrange.offset = 0; + mrange.size = VK_WHOLE_SIZE; // size; + VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); + } + + stream.Launch([&](VulkanStreamState* state) { + // 0: barrier(host->transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = 0; + barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, + nullptr); + // 1: copy + VkBufferCopy copy_info; + copy_info.srcOffset = 0; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1, + ©_info); + }); + // TODO(tulloch): should we instead make the staging buffer a property of the + // Stream? This would allow us to elide synchronizations here. + stream.Synchronize(); + } else { + LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" + << ", from=" << from_dev_type << ", to=" << to_dev_type; + } +} + +const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { + ICHECK_LT(device_id, devices_.size()) << "Requested Vulkan device_id=" << device_id + << ", but only " << devices_.size() << " devices present"; + return devices_[device_id]; +} + +VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { + return const_cast(const_cast(this)->device(device_id)); +} + +TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = VulkanDeviceAPI::Global(); + *rv = static_cast(ptr); +}); + +TVM_REGISTER_GLOBAL("device_api.vulkan.get_target_property") + .set_body_typed([](Device dev, const std::string& property) { + TVMRetValue rv; + VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); + return rv; + }); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h new file mode 100644 index 000000000000..b8be3eb43c79 --- /dev/null +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ + +#include +#include + +#include +#include + +#include "../thread_map.h" +#include "../workspace_pool.h" +#include "vulkan/vulkan_core.h" +#include "vulkan_device.h" +#include "vulkan_instance.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanDeviceAPI final : public DeviceAPI { + public: + static VulkanDeviceAPI* Global(); + VulkanDeviceAPI(); + ~VulkanDeviceAPI(); + + // Implement active device + void SetDevice(Device dev) final; + void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + + // Implement memory management required by DeviceAPI + void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; + void FreeDataSpace(Device dev, void* ptr) final; + void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; + void FreeWorkspace(Device dev, void* data) final; + + // Current vulkan implementation has one "stream" per CPU thread, + // with all commands writing into a single command buffer that is + // submitted on a call to StreamSync. Therefore, for now, these are + // mostly no-ops. If needed in the future, could have multiple + // command buffers to act as multiple streams. + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; + void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final; + void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; + + protected: + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + Device dev_from, Device dev_to, DLDataType type_hint, + TVMStreamHandle stream) final; + + // End of required methods for the DeviceAPI interface + + public: + /*! \brief Return the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + VulkanDevice& GetActiveDevice(); + + /*! \brief Return the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + int GetActiveDeviceID(); + + /*! \brief Return the VulkanDevice associated with a specific device_id + * + * These are constructed during VulkanDeviceAPI initialization, so + * this function returns immediately. + */ + const VulkanDevice& device(size_t device_id) const; + + /*! \brief Return the VulkanDevice associated with a specific device_id + * + * These are constructed during VulkanDeviceAPI initialization, so + * this function returns immediately. + */ + VulkanDevice& device(size_t device_id); + + /*! \brief Returns a property to be stored in a target. + * + * Returns the results of feature/property queries done during the + * device initialization. + */ + void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv); + + private: + std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); + + /*! \brief The Vulkan API instance owned by the VulkanDeviceAPI + * + * Holds and manages VkInstance. + */ + VulkanInstance instance_; + + /*! \brief Handles to the Vulkan devices + * + * The physical devices. These are constructed after the instance_, + * and must be destructed before the instance_. + */ + std::vector devices_; + + /*! \brief One pool of device memory for each CPU thread. + * + * These allocate memory based on the devices stored in devices_. + * The memory pools must be destructed before devices_. + */ + ThreadMap pool_per_thread; + + /*! \brief The index of the active device for each CPU thread. + * + * To mimic the semantics of cudaSetDevice, each CPU thread can set + * the device on which functions should run. If unset, the active + * device defaults to device_id == 0. + */ + ThreadMap active_device_id_per_thread; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/runtime/vulkan/vulkan_instance.cc new file mode 100644 index 000000000000..351319e0e898 --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.cc @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_instance.h" + +#include +#include + +#include "vulkan_common.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanInstance::VulkanInstance() { + const auto layers = []() { + std::vector layers; + + const char* validation_enabled_env = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); + bool validation_enabled = validation_enabled_env && *validation_enabled_env; + + if (validation_enabled) { + uint32_t inst_layer_prop_count; + VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); + std::vector inst_layer_prop(inst_layer_prop_count); + VULKAN_CALL( + vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); + + for (const auto& lp : inst_layer_prop) { + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_standard_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_parameter_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + } + } + return layers; + }(); + + { + std::vector required_extensions{}; + std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; + + uint32_t inst_extension_prop_count; + VULKAN_CALL( + vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); + std::vector inst_extension_prop(inst_extension_prop_count); + VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, + inst_extension_prop.data())); + + enabled_extensions_ = + FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); + } + + uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); + { + // Result from vkGetInstanceProcAddr is NULL if driver only + // supports vulkan 1.0. + auto vkEnumerateInstanceVersion = + (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); + if (vkEnumerateInstanceVersion) { + vkEnumerateInstanceVersion(&api_version); + } + } + + { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "TVM"; + app_info.applicationVersion = 0; + app_info.pEngineName = ""; + app_info.engineVersion = 0; + app_info.apiVersion = api_version; + + VkInstanceCreateInfo inst_info; + inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + inst_info.pNext = nullptr; + inst_info.flags = 0; + inst_info.pApplicationInfo = &app_info; + inst_info.enabledLayerCount = layers.size(); + inst_info.ppEnabledLayerNames = layers.data(); + inst_info.enabledExtensionCount = enabled_extensions_.size(); + inst_info.ppEnabledExtensionNames = enabled_extensions_.data(); + + VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + } +} + +VulkanInstance::~VulkanInstance() { + if (instance_) { + vkDestroyInstance(instance_, nullptr); + } +} + +VulkanInstance::VulkanInstance(VulkanInstance&& other) { do_swap(std::move(other)); } + +VulkanInstance& VulkanInstance::operator=(VulkanInstance&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanInstance::do_swap(VulkanInstance&& other) { + if (this == &other) { + return; + } + + std::swap(enabled_extensions_, other.enabled_extensions_); + std::swap(instance_, other.instance_); +} + +bool VulkanInstance::HasExtension(const char* query) const { + return std::any_of(enabled_extensions_.begin(), enabled_extensions_.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +std::vector VulkanInstance::GetPhysicalDevices() const { + uint32_t device_count = 0; + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &device_count, nullptr)); + std::vector devices(device_count); + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &device_count, devices.data())); + return devices; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_instance.h b/src/runtime/vulkan/vulkan_instance.h new file mode 100644 index 000000000000..06016d8f0aea --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ + +#include + +#include "vulkan/vulkan_core.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanInstance { + public: + VulkanInstance(); + ~VulkanInstance(); + + // Allow move assignment/construction + VulkanInstance(VulkanInstance&&); + VulkanInstance& operator=(VulkanInstance&&); + + // Forbid copy assignment/construction + VulkanInstance(const VulkanInstance&) = delete; + VulkanInstance& operator=(const VulkanInstance&) = delete; + + /*! \brief Expose the internal VkInstance + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkInstance handler itself. + */ + operator VkInstance() const { return instance_; } + + /*! \brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; + + /*! \brief Return all accessible physical devices + * + * Wrapper around vkEnumeratePhysicalDevices. + */ + std::vector GetPhysicalDevices() const; + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanInstance&& other); + + /*! \brief Extensions enabled for this instance + * + * Based on supported extensions queried through + * vkEnumerateInstanceExtensionProperties, prior to creating + * instance_. Contains only statically allocated string literals, + * no cleanup required. + */ + std::vector enabled_extensions_; + + //! \brief The Vulkan API instance handle + VkInstance instance_{nullptr}; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc new file mode 100644 index 000000000000..89104d9d63d9 --- /dev/null +++ b/src/runtime/vulkan/vulkan_module.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_module.h" + +#include +#include + +#include "../file_utils.h" +#include "vulkan_wrapped_func.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string source) { + auto n = make_object(smap, fmap, source); + return Module(n); +} + +Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { + std::string data; + std::unordered_map smap; + std::unordered_map fmap; + std::string fmt = GetFileFormat(file_name, format); + std::string meta_file = GetMetaFilePath(file_name); + LoadBinaryFromFile(file_name, &data); + LoadMetaDataFromFile(meta_file, &fmap); + dmlc::MemoryStringStream fs(&data); + dmlc::Stream* stream = &fs; + uint32_t magic; + stream->Read(&magic); + ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; + stream->Read(&smap); + return VulkanModuleCreate(smap, fmap, ""); +} + +Module VulkanModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map smap; + std::unordered_map fmap; + + std::string fmt; + stream->Read(&fmt); + stream->Read(&fmap); + stream->Read(&smap); + return VulkanModuleCreate(smap, fmap, ""); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc new file mode 100644 index 000000000000..3eff112a6eea --- /dev/null +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_stream.h" + +#include "vulkan_device.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanStream::VulkanStream(const VulkanDevice* device) + : device_(device), state_(new VulkanStreamState()) { + // create command pool + VkCommandPoolCreateInfo cmd_pool_cinfo; + cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + cmd_pool_cinfo.pNext = nullptr; + cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + cmd_pool_cinfo.queueFamilyIndex = device_->queue_family_index; + VULKAN_CALL(vkCreateCommandPool(*device_, &cmd_pool_cinfo, nullptr, &cmd_pool_)); + + VkCommandBufferAllocateInfo buffer_alloc_info; + buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + buffer_alloc_info.pNext = nullptr; + buffer_alloc_info.commandPool = cmd_pool_; + buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + buffer_alloc_info.commandBufferCount = 1; + VULKAN_CALL(vkAllocateCommandBuffers(*device_, &buffer_alloc_info, &(state_->cmd_buffer_))); + + VkFenceCreateInfo fence_cinfo; + fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + fence_cinfo.pNext = nullptr; + fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; + VULKAN_CALL(vkCreateFence(*device_, &fence_cinfo, nullptr, &(state_->fence_))); + + VkCommandBufferBeginInfo cb_begin; + cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + cb_begin.pNext = nullptr; + cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + cb_begin.pInheritanceInfo = 0; + VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); +} + +VulkanStream::~VulkanStream() { + vkDestroyFence(*device_, state_->fence_, nullptr); + vkDestroyCommandPool(*device_, cmd_pool_, nullptr); +} + +void VulkanStream::Launch(const std::function& kernel) { + if (device_->UseImmediate()) { + kernel(state_.get()); + } else { + deferred_kernels_.push_back(kernel); + } +} + +void VulkanStream::LaunchDeferred(const std::function& deferred_initializer, + const std::function& deferred_kernel, + const VulkanStreamToken& deferred_token) { + ICHECK(!device_->UseImmediate()); + + // If the new kernel uses the same descriptor set as one of the + // kernels already in the command buffer, we need to synchronize + // first. + if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), + [&](const VulkanStreamToken& token) { + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + return token.descriptor_set_ == deferred_token.descriptor_set_ && + token.buffers_ != deferred_token.buffers_; + })) { + Synchronize(); + } + + // If the new kernel uses the same buffers in the same descriptor + // set as an already-queued kernel, we don't need to initialize it + // again. Since every VulkanWrappedFunc owns a single descriptor + // set, unless the same function is called with the same buffer + // arguments, deferred_initializer() will always be called. + if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), + [&](const VulkanStreamToken& token) { + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + return token.descriptor_set_ == deferred_token.descriptor_set_ && + token.buffers_ == deferred_token.buffers_; + })) { + deferred_initializer(); + } + + // Save the kernel itself to be called later. + deferred_kernels_.push_back(deferred_kernel); + deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token); +} + +void VulkanStream::Synchronize() { + if (!device_->UseImmediate()) { + for (const auto& deferred_kernel : deferred_kernels_) { + deferred_kernel(state_.get()); + } + deferred_kernels_.clear(); + deferred_tokens_.clear(); + } else { + DCHECK_EQ(deferred_kernels_.size(), 0); + DCHECK_EQ(deferred_tokens_.size(), 0); + } + + VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_)); + VkSubmitInfo cb_submit; + cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + cb_submit.pNext = nullptr; + cb_submit.waitSemaphoreCount = 0; + cb_submit.pWaitSemaphores = nullptr; + cb_submit.pWaitDstStageMask = 0; + cb_submit.commandBufferCount = 1; + cb_submit.pCommandBuffers = &(state_->cmd_buffer_); + cb_submit.signalSemaphoreCount = 0; + cb_submit.pSignalSemaphores = nullptr; + + device_->QueueSubmit(cb_submit, state_->fence_); + + uint64_t timeout = 1UL << 30UL; + VkResult res; + do { + res = vkWaitForFences(*device_, 1, &(state_->fence_), 0, timeout); + } while (res == VK_TIMEOUT); + VULKAN_CHECK_ERROR(res); + VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); + VULKAN_CALL(vkResetFences(*device_, 1, &(state_->fence_))); + + // Re-initialize the command buffer + VkCommandBufferBeginInfo cb_begin; + cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + cb_begin.pNext = nullptr; + cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + cb_begin.pInheritanceInfo = 0; + VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index d096a644a1f0..fb4e447c15e1 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -31,6 +31,8 @@ namespace tvm { namespace runtime { namespace vulkan { +class VulkanDevice; + class VulkanStreamState { public: VkCommandBuffer cmd_buffer_; @@ -43,138 +45,65 @@ struct VulkanStreamToken { std::vector buffers_; }; +/*! + * \brief Wrapper around a vulkan command buffer + * + * The VulkanStream collects commands into a VkCommandBuffer. When a + * newly submitted command requires resources reserved by an + * already-submitted command, all of the queued commands are + * submitted to the GPU, and the CPU waits for all queued commands to + * finish. The queued commands can also be explicitly pushed/waited + * on by calling VulkanStream::Synchronize. + * + * Currently, there exists one VulkanStream for each GPU device, for + * each CPU thread. Each time a VulkanWrappedFunc is called, it is + * submitted to the VulkanStream associated with the submitting CPU + * thread, and associated the thread-specific active device set by + * `DeviceAPI::SetDevice`. + */ class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) { - // create command pool - VkCommandPoolCreateInfo cmd_pool_cinfo; - cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - cmd_pool_cinfo.pNext = nullptr; - cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; - cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index; - VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); - - VkCommandBufferAllocateInfo buffer_alloc_info; - buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - buffer_alloc_info.pNext = nullptr; - buffer_alloc_info.commandPool = cmd_pool_; - buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - buffer_alloc_info.commandBufferCount = 1; - VULKAN_CALL( - vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); - - VkFenceCreateInfo fence_cinfo; - fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - fence_cinfo.pNext = nullptr; - fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; - VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_))); - - VkCommandBufferBeginInfo cb_begin; - cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - cb_begin.pNext = nullptr; - cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - cb_begin.pInheritanceInfo = 0; - VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); - } - - ~VulkanStream() { - vkDestroyFence(vctx_->device, state_->fence_, nullptr); - vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr); - } - - // Launch the kernel on the current stream. - void Launch(const std::function& kernel) { - if (vctx_->UseImmediate()) { - kernel(state_.get()); - } else { - deferred_kernels_.push_back(kernel); - } - } - - // Launch the kernel on the current stream, + explicit VulkanStream(const VulkanDevice* device); + + ~VulkanStream(); + + /*! \brief Push the kernel onto the stream's command buffer. + * + * If device.UseImmediate() is true, the kernel is executed + * immediately to update the command buffer. Otherwise, it is added + * to the list of deferred updates to be pushed onto the command + * buffer. + * + * Assumes that there are no descriptor sets or buffers accessed by this kernel. + * + */ + void Launch(const std::function& kernel); + + /*! \brief Push the kernel onto the stream's command buffer. + * + * Can only be called if device.UseImmediate() is false. The + * kernel is delayed, and isn't pushed to the command buffer until + * all kernels are collected. + * + * \param deferred_initializer Updates the descriptor set. Only + * called if the deferred_token has differences from + * + * \param deferred_kernel Submits updates to the command buffer. + * + * \param deferred_token Indicates which descriptor set and buffers + * are accessed by this kernel. No two kernels in the command + * buffer can use the same descriptor set. + * + */ void LaunchDeferred(const std::function& deferred_initializer, const std::function& deferred_kernel, - const VulkanStreamToken& deferred_token) { - ICHECK(!vctx_->UseImmediate()); - - // It is invalid to schedule this instance on the current stream if we already - // have a matching descriptor set and a non-matching buffer set. - if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), - deferred_tokens_[deferred_token.descriptor_set_].end(), - [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); - return token.descriptor_set_ == deferred_token.descriptor_set_ && - token.buffers_ != deferred_token.buffers_; - })) { - Synchronize(); - } - - // It is unnecessary to invoke our initializer if we have a matching token. - if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), - deferred_tokens_[deferred_token.descriptor_set_].end(), - [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); - return token.descriptor_set_ == deferred_token.descriptor_set_ && - token.buffers_ == deferred_token.buffers_; - })) { - deferred_initializer(); - } - - deferred_kernels_.push_back(deferred_kernel); - deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token); - } + const VulkanStreamToken& deferred_token); // Synchronize the current stream `state_` with respect to the host. - void Synchronize() { - if (!vctx_->UseImmediate()) { - for (const auto& deferred_kernel : deferred_kernels_) { - deferred_kernel(state_.get()); - } - deferred_kernels_.clear(); - deferred_tokens_.clear(); - } else { - DCHECK_EQ(deferred_kernels_.size(), 0); - DCHECK_EQ(deferred_tokens_.size(), 0); - } - - VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_)); - VkSubmitInfo cb_submit; - cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; - cb_submit.pNext = nullptr; - cb_submit.waitSemaphoreCount = 0; - cb_submit.pWaitSemaphores = nullptr; - cb_submit.pWaitDstStageMask = 0; - cb_submit.commandBufferCount = 1; - cb_submit.pCommandBuffers = &(state_->cmd_buffer_); - cb_submit.signalSemaphoreCount = 0; - cb_submit.pSignalSemaphores = nullptr; - - { - // Multiple streams (on different threads) use the same VulkanContext - // instance, so we need to externally synchronize accesses. - std::lock_guard g(*(vctx_->queue_mutex)); - VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_)); - } - uint64_t timeout = 1UL << 30UL; - VkResult res; - do { - res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout); - } while (res == VK_TIMEOUT); - VULKAN_CHECK_ERROR(res); - VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); - VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_))); - - // Re-initialize the command buffer - VkCommandBufferBeginInfo cb_begin; - cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - cb_begin.pNext = nullptr; - cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - cb_begin.pInheritanceInfo = 0; - VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); - } + void Synchronize(); private: - const VulkanContext* vctx_; + const VulkanDevice* device_; std::unique_ptr state_; // An index of deferred tokens, allowing us to efficiently detect duplicated // deferred_initializer blocks. diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc new file mode 100644 index 000000000000..103b2aa7692c --- /dev/null +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_wrapped_func.h" + +#include + +#include + +#include "../file_utils.h" +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, + const std::string& func_name, size_t num_buffer_args, + size_t num_pack_args, + const std::vector& thread_axis_tags) { + m_ = m; + sptr_ = sptr; + func_name_ = func_name; + num_buffer_args_ = num_buffer_args; + num_pack_args_ = num_pack_args; + thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); +} + +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, + const ArgUnion64* pack_args) const { + int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID(); + auto& device = VulkanDeviceAPI::Global()->device(device_id); + if (!scache_[device_id]) { + scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); + } + const auto& pipeline = scache_[device_id]; + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + std::vector descriptor_buffers; + descriptor_buffers.resize(num_buffer_args_); + for (size_t i = 0; i < num_buffer_args_; ++i) { + void* buf = args[static_cast(i)]; + VkDescriptorBufferInfo binfo; + binfo.buffer = static_cast(buf)->buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers[i] = binfo; + } + const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); + if (pipeline->use_ubo) { + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + VkDescriptorBufferInfo binfo; + binfo.buffer = ubo.vk_buf.buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers.push_back(binfo); + } + if (device.UseImmediate()) { + // Can safely capture by reference as this lambda is immediately executed on the calling thread. + device.ThreadLocalStream().Launch([&](VulkanStreamState* state) { + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); + ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); + device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( + state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, + descriptor_buffers.data()); + + if (pipeline->use_ubo) { + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args, nbytes_scalars); + } else if (num_pack_args_ > 0) { + vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), + pack_args); + } + + vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }); + return; + } + + // Otherwise, the more expensive deferred path. + std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); + const auto& deferred_initializer = [&device, pipeline, descriptor_buffers]() { + std::vector write_descriptor_sets; + write_descriptor_sets.resize(descriptor_buffers.size()); + for (size_t i = 0; i < write_descriptor_sets.size(); i++) { + write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + write_descriptor_sets[i].pNext = 0; + write_descriptor_sets[i].dstSet = pipeline->descriptor_set; + write_descriptor_sets[i].dstBinding = i; + write_descriptor_sets[i].dstArrayElement = 0; + write_descriptor_sets[i].descriptorCount = 1; + write_descriptor_sets[i].pImageInfo = 0; + write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); + write_descriptor_sets[i].pTexelBufferView = 0; + + if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) { + // The last binding is for UBO + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; + } else { + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + } + } + vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, + 0); + }; + const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, + device_id](VulkanStreamState* state) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); + vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, + nullptr); + + if (pipeline->use_ubo) { + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args_storage.data(), nbytes_scalars); + } else if (num_pack_args_ > 0) { + vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, + 0, pack_args_storage.size() * sizeof(ArgUnion64), + pack_args_storage.data()); + } + + vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }; + VulkanStreamToken deferred_token; + deferred_token.descriptor_set_ = pipeline->descriptor_set; + deferred_token.buffers_.resize(descriptor_buffers.size()); + for (size_t i = 0; i < descriptor_buffers.size(); ++i) { + deferred_token.buffers_[i] = descriptor_buffers[i].buffer; + } + device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token); +} + +VulkanModuleNode::~VulkanModuleNode() { + // cleanup vulkan related caches. + for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) { + for (auto& kv : ecache_[device_id]) { + auto& pe = kv.second; + ICHECK(pe); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); + + if (pe->descriptor_update_template != VK_NULL_HANDLE) { + device.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( + device, pe->descriptor_update_template, nullptr); + } + vkDestroyPipeline(device, pe->pipeline, nullptr); + vkDestroyPipelineLayout(device, pe->pipeline_layout, nullptr); + vkDestroyDescriptorPool(device, pe->descriptor_pool, nullptr); + vkDestroyDescriptorSetLayout(device, pe->descriptor_set_layout, nullptr); + vkDestroyShaderModule(device, pe->shader, nullptr); + } + } +} + +PackedFunc VulkanModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + ICHECK_EQ(sptr_to_self.get(), this); + ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + auto it = fmap_.find(name); + if (it == fmap_.end()) return PackedFunc(); + const FunctionInfo& info = it->second; + VulkanWrappedFunc f; + size_t num_buffer_args = NumBufferArgs(info.arg_types); + f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, + info.thread_axis_tags); + return PackFuncNonBufferArg(std::move(f), info.arg_types); +} + +std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, + const std::string& func_name, + size_t num_pack_args) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + std::lock_guard lock(mutex_); + const auto& cp = ecache_[device_id][func_name]; + if (cp) { + return cp; + } + // Create new pipeline + auto pe = std::make_shared(); + { + // create shader + auto sit = smap_.find(func_name); + ICHECK(sit != smap_.end()); + pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO); + const std::vector& data = sit->second.data; + VkShaderModuleCreateInfo shader_cinfo; + shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_cinfo.pNext = nullptr; + shader_cinfo.flags = 0; + shader_cinfo.codeSize = data.size() * sizeof(uint32_t); + shader_cinfo.pCode = data.data(); + VULKAN_CALL(vkCreateShaderModule(device, &shader_cinfo, nullptr, &(pe->shader))); + } + std::vector arg_binding; + std::vector arg_template; + std::vector descriptor_set_pool_sizes; + uint32_t num_pod = 0, num_buffer = 0; + + auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes]( + uint32_t binding, VkDescriptorType desc_type) { + { + auto result = std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(), + [&](const auto& psize) { return psize.type == desc_type; }); + if (result == descriptor_set_pool_sizes.end()) { + VkDescriptorPoolSize new_size; + new_size.type = desc_type; + new_size.descriptorCount = 1; + descriptor_set_pool_sizes.push_back(new_size); + } else { + result->descriptorCount++; + } + } + + { + VkDescriptorSetLayoutBinding bd; + bd.binding = binding; + bd.descriptorType = desc_type; + bd.descriptorCount = 1; + bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + bd.pImmutableSamplers = nullptr; + arg_binding.push_back(bd); + } + { + VkDescriptorUpdateTemplateEntryKHR tpl; + tpl.dstBinding = binding; + tpl.dstArrayElement = 0; + tpl.descriptorCount = 1; + tpl.descriptorType = desc_type; + tpl.offset = binding * sizeof(VkDescriptorBufferInfo); + tpl.stride = sizeof(VkDescriptorBufferInfo); + arg_template.push_back(tpl); + } + }; + + { + auto fit = fmap_.find(func_name); + ICHECK(fit != fmap_.end()); + for (DLDataType arg_type : fit->second.arg_types) { + if (arg_type.code == kTVMOpaqueHandle) { + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); + ++num_buffer; + } else { + ++num_pod; + } + } + } + + size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); + if (pe->use_ubo) { + // Use UBO instead of push constants + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); + device.AllocateThreadLocalUniformBuffer(nbytes_scalars); + } + + { + VkDescriptorSetLayoutCreateInfo descrip_cinfo; + descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + descrip_cinfo.pNext = nullptr; + descrip_cinfo.flags = 0; + if (device.UseImmediate()) { + descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + descrip_cinfo.bindingCount = arg_binding.size(); + descrip_cinfo.pBindings = arg_binding.data(); + VULKAN_CALL( + vkCreateDescriptorSetLayout(device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout))); + } + + if (!device.UseImmediate()) { + VkDescriptorPoolCreateInfo descrip_pool_cinfo; + descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + descrip_pool_cinfo.pNext = nullptr; + descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT; + descrip_pool_cinfo.maxSets = 1; + descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); + descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); + VULKAN_CALL( + vkCreateDescriptorPool(device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); + + VkDescriptorSetAllocateInfo alloc_info; + alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + alloc_info.pNext = nullptr; + alloc_info.descriptorPool = pe->descriptor_pool; + alloc_info.descriptorSetCount = 1; + alloc_info.pSetLayouts = &(pe->descriptor_set_layout); + VULKAN_CALL(vkAllocateDescriptorSets(device, &alloc_info, &(pe->descriptor_set))); + } + + VkPushConstantRange crange; + crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + crange.offset = 0; + crange.size = sizeof(ArgUnion64) * num_pack_args; + + VkPipelineLayoutCreateInfo playout_cinfo; + playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + playout_cinfo.pNext = nullptr; + playout_cinfo.flags = 0; + playout_cinfo.setLayoutCount = 1; + playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); + + if (0 < nbytes_scalars && !pe->use_ubo) { + playout_cinfo.pushConstantRangeCount = 1; + playout_cinfo.pPushConstantRanges = &crange; + ICHECK_LE(crange.size, device.device_properties.max_push_constants_size) + << "The Vulkan shader uses " << crange.size + << " bytes of push constants, but the device only supports " + << device.device_properties.max_push_constants_size << "bytes. " + << "Please rebuild the shader using a smaller limit on push constants size " + << "by passing -max_push_constants_size=N in the Target string, " + << "or pass -from_device=0 to query all device parameters."; + } else { + playout_cinfo.pushConstantRangeCount = 0; + playout_cinfo.pPushConstantRanges = nullptr; + } + + VULKAN_CALL(vkCreatePipelineLayout(device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); + + VkComputePipelineCreateInfo pipeline_cinfo; + pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_cinfo.pNext = nullptr; + pipeline_cinfo.flags = 0; + pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeline_cinfo.stage.pNext = nullptr; + pipeline_cinfo.stage.flags = 0; + pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeline_cinfo.stage.module = pe->shader; + pipeline_cinfo.stage.pName = func_name.c_str(); + pipeline_cinfo.stage.pSpecializationInfo = nullptr; + pipeline_cinfo.layout = pe->pipeline_layout; + pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; + pipeline_cinfo.basePipelineIndex = 0; + VULKAN_CALL(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, + &(pe->pipeline))); + + if (device.UseImmediate()) { + VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; + descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; + descrip_template_cinfo.pNext = 0; + descrip_template_cinfo.flags = 0; + descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size(); + descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data(); + descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR; + descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout; + descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; + descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; + descrip_template_cinfo.set = 0; + VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( + device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); + } + ecache_[device_id][func_name] = pe; + return pe; +} + +void VulkanModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + std::string fmt = GetFileFormat(file_name, format); + ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; + std::string meta_file = GetMetaFilePath(file_name); + SaveMetaDataToFile(meta_file, fmap_); + std::string data_bin; + dmlc::MemoryStringStream fs(&data_bin); + dmlc::Stream* stream = &fs; + uint32_t magic = kVulkanModuleMagic; + stream->Write(magic); + stream->Write(smap_); + SaveBinaryToFile(file_name, data_bin); +} + +void VulkanModuleNode::SaveToBinary(dmlc::Stream* stream) { + stream->Write(fmt_); + stream->Write(fmap_); + stream->Write(smap_); +} + +std::string VulkanModuleNode::GetSource(const std::string& format) { + // can only return disassembly code. + return source_; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h new file mode 100644 index 000000000000..a174f22eba59 --- /dev/null +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_ + +#include +#include +#include +#include +#include +#include + +#include "../meta_data.h" +#include "../pack_args.h" +#include "../thread_storage_scope.h" +#include "vulkan/vulkan_core.h" +#include "vulkan_common.h" +#include "vulkan_device.h" +#include "vulkan_shader.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +struct VulkanPipeline { + VulkanDevice* device{nullptr}; + VkShaderModule shader{VK_NULL_HANDLE}; + VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; + VkDescriptorPool descriptor_pool{VK_NULL_HANDLE}; + VkDescriptorSet descriptor_set{VK_NULL_HANDLE}; + VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; + VkPipeline pipeline{VK_NULL_HANDLE}; + VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; + bool use_ubo{false}; +}; + +class VulkanModuleNode; + +// a wrapped function class to get packed func. +class VulkanWrappedFunc { + public: + void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_buffer_args, size_t num_pack_args, + const std::vector& thread_axis_tags); + + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; + + private: + // internal module + VulkanModuleNode* m_; + // the resource holder + ObjectPtr sptr_; + // v The name of the function. + std::string func_name_; + // Number of buffer arguments + size_t num_buffer_args_; + // number of packed arguments. + size_t num_pack_args_; + // Device state cache per device. + // mark as mutable, to enable lazy initialization + // thread axis configuration + ThreadAxisConfig thread_axis_cfg_; + + mutable std::array, kVulkanMaxNumDevice> scache_; +}; + +class VulkanModuleNode final : public runtime::ModuleNode { + public: + explicit VulkanModuleNode(std::unordered_map smap, + std::unordered_map fmap, std::string source) + : smap_(smap), fmap_(fmap), source_(source) {} + ~VulkanModuleNode(); + + const char* type_key() const final { return "vulkan"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, + size_t num_pack_args); + + void SaveToFile(const std::string& file_name, const std::string& format) final; + + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string& format) final; + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; + // The format + std::string fmt_{"vulkan"}; + // The source + std::string source_; + + // Guards accesses to `ecache_` + std::mutex mutex_; + std::array>, kVulkanMaxNumDevice> + ecache_; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_ diff --git a/src/support/array.h b/src/support/array.h new file mode 100644 index 000000000000..2cf416c471ec --- /dev/null +++ b/src/support/array.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SUPPORT_ARRAY_H_ +#define TVM_SUPPORT_ARRAY_H_ +#include + +#include + +namespace tvm { +namespace support { + +/*! + * \brief Checks if two arrays contain the same objects + * \tparam T The type of objects in the array + * \param a The first array + * \param b The second array + * \return A boolean indicating if they are the same + */ +template +inline bool ArrayWithSameContent(const Array& a, const Array& b) { + if (a.size() != b.size()) { + return false; + } + int n = a.size(); + for (int i = 0; i < n; ++i) { + if (!a[i].same_as(b[i])) { + return false; + } + } + return true; +} + +/*! + * \brief Checks if two arrays contain the same objects + * \tparam T The type of objects in the array + * \param a The first array + * \param b The second array + * \return A boolean indicating if they are the same + */ +template +inline bool ArrayWithSameContent(const std::vector& a, const std::vector& b) { + if (a.size() != b.size()) { + return false; + } + int n = a.size(); + for (int i = 0; i < n; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} + +} // namespace support +} // namespace tvm +#endif // TVM_SUPPORT_ARRAY_H_ diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index ea3a22e8ab01..4b5dc9080df1 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include diff --git a/src/support/utils.h b/src/support/utils.h index 075351760686..d807c5b8bb63 100644 --- a/src/support/utils.h +++ b/src/support/utils.h @@ -32,7 +32,7 @@ #endif // __hexagon__ #endif // _WIN32 -#include +#include #include #include diff --git a/src/target/build_common.h b/src/target/build_common.h index 1816c3ac2650..d2fe6468eef8 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -25,7 +25,6 @@ #define TVM_TARGET_BUILD_COMMON_H_ #include -#include #include #include #include diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 19b7ad7b1d8f..5a4aa39f01b4 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -47,13 +46,9 @@ runtime::Module Build(IRModule mod, Target target) { .value()) { mod = tir::transform::SkipAssert()(mod); } - std::string build_f_name; - if (target->kind->name == "micro_dev") { - build_f_name = "target.build.c"; - } else { - build_f_name = "target.build." + target->kind->name; - } + // the build function. + std::string build_f_name = "target.build." + target->kind->name; const PackedFunc* bf = runtime::Registry::Get(build_f_name); ICHECK(bf != nullptr) << build_f_name << " is not enabled"; return (*bf)(mod, target); diff --git a/src/target/func_registry_generator.h b/src/target/func_registry_generator.h index fb5964859352..8d2af305a0e4 100644 --- a/src/target/func_registry_generator.h +++ b/src/target/func_registry_generator.h @@ -24,7 +24,8 @@ #ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ #define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ -#include +#include +#include #include #include diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 5dbceec32ed7..42957152ea12 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index d5140677d45a..48ccefafe3c4 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -62,7 +62,7 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, md_builder_.reset(new llvm::MDBuilder(*ctx_)); // types t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(); + t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); t_int_ = llvm::Type::getInt32Ty(*ctx_); t_char_ = llvm::Type::getInt8Ty(*ctx_); t_int8_ = llvm::Type::getInt8Ty(*ctx_); @@ -191,20 +191,10 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... - std::vector param_types; - // args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // num_args - param_types.push_back(t_int_); - // ret_args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // ret_tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // resource_handle - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); + llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace()); + // args, tcodes, num_args, ret_value, ret_tcode, resource_handle + std::vector param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_}; llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false); llvm::Function* function = @@ -215,41 +205,29 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - std::vector zero_index_list{llvm::ConstantInt::get(t_int32_, 0)}; - std::vector zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0), - llvm::ConstantInt::get(t_int32_, 0)}; - auto args_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[0], + + auto getArg = [function](int i) -> llvm::Argument* { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return &function->arg_begin()[i]; #else - &(*(function->arg_begin())), + return &*std::next(function->arg_begin(), i); #endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)); - llvm::Value* sid = builder_->CreateBitCast( - builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()), - builder_->CreateInBoundsGEP(args_array, zero_index_list)), - t_int64_); + }; + + llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); + llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + + auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); + auto ret_value = + builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); - auto ret_types_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[4], -#else - &(*(std::next(function->arg_begin(), 4))), -#endif - llvm::ArrayType::get(t_int_, 1)->getPointerTo()); - auto retval_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[3], -#else - &(*std::next(function->arg_begin(), 3)), -#endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo()); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); builder_->SetInsertPoint(default_block); - builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode); builder_->CreateRet(ConstInt32(kTvmErrorNoError)); // Add data to the global section. @@ -258,16 +236,20 @@ void CodeGenLLVM::LinkParameters(const Map params) { std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + auto dtype = tvm::runtime::DataType(kv.second->param->dtype); + size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); +#if TVM_LLVM_VERSION >= 100 + param_symbol->setAlignment(llvm::Align(align)); +#else + param_symbol->setAlignment(align); +#endif llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); - builder_->CreateStore( - builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())), - builder_->CreateInBoundsGEP(retval_array, zero_array_index_list)); - builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); builder_->CreateRet(ConstInt32(0)); } } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e56a6de6d914..d5fcfab6d889 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/src/target/llvm/codegen_params.h b/src/target/llvm/codegen_params.h index 771bc201f7aa..f5fd21ff326d 100644 --- a/src/target/llvm/codegen_params.h +++ b/src/target/llvm/codegen_params.h @@ -24,7 +24,6 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ #define TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ -#include #include #include "llvm_common.h" diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 1791a5574c11..b967c7ad44e0 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -37,7 +37,6 @@ #include #include #include -#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -78,6 +77,7 @@ #include #include #include +#include #include #include diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 3eab00c643e5..6b05d4bdf2d5 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_LLVM_LLVM_MODULE_H_ #define TVM_TARGET_LLVM_LLVM_MODULE_H_ -#include #include #include diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index add05ba52692..9311ee78ca6a 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_METADATA_MODULE_H_ #define TVM_TARGET_METADATA_MODULE_H_ -#include #include #include #include diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 1a0f08920fb6..4a2917daa5ed 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -67,6 +67,10 @@ std::string FindCUDAIncludePath() { if (stat(cuda_include_path.c_str(), &st) == 0) { return cuda_include_path; } + + if (stat("/usr/include/cuda.h", &st) == 0) { + return "/usr/include"; + } #endif LOG(FATAL) << "Cannot find cuda include path." << "CUDA_PATH is not set or CUDA is not installed in the default installation path." diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index b3ed7cf32f7f..17e38e9af6e6 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -84,7 +84,7 @@ TVM_REGISTER_GLOBAL("target.build.aocl") return BuildAOCL(mod, target, false); }); -TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") +TVM_REGISTER_GLOBAL("target.build.aocl_sw_emu") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { return BuildAOCL(mod, target, true); }); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index d4d0e54c6db4..99c9452975d4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -212,13 +212,18 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i PrintType(t.element_of(), os); os << "*)"; } - os << vid << " + ("; - PrintExpr(index, os); - os << ")"; if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << " / " << (32 / t.bits()); + os << vid << ") + ("; + PrintExpr(index, os); + os << ")"; + os << " / " << t.lanes(); + os << ")[0]"; + } else { + os << vid << " + ("; + PrintExpr(index, os); + os << ")"; + os << "))[0]"; } - os << "))[0]"; } return os.str(); } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 76e6a9bc7197..ae451f39f89b 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -25,7 +25,6 @@ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ #include -#include #include #include #include diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 03fef4709b5e..2d93989730c7 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,7 +22,6 @@ */ #include "codegen_c_host.h" -#include #include #include #include diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4cc999bf9136..6e76c3538e71 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -809,18 +809,48 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) { - // make_int4x8 + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + bool fail = false; const int64_t* p = as_const_int(op->value); ICHECK(p); int64_t v = *p & 0xF; - v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { - os << "(uint)" << v; + + if (op->lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; + if (op->dtype.is_uint()) { + os << "(uint16_t)" << v; + } else { + os << "(int16_t)" << v; + } } else { - os << "(int)" << v; + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; + if (op->lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (op->lanes == 16 || op->lanes == 32) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < op->lanes / 8; ++i) { + if (i != 0) os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + + if (!fail) { + return; } - return; } std::string v = PrintExpr(op->value); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 4d2bf53f2252..71e3529e0d80 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -325,27 +325,30 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO runtime::Module BuildMetal(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - CodeGenMetal cg; - cg.Init(output_ssa); + std::stringstream code; + std::stringstream source; + std::string fmt = "metal"; for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + code << "// Function: " << kv.first->name_hint << std::endl; + CodeGenMetal cg; + cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); + std::string fsource = cg.Finish(); + if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { + source << fsource; + fsource = (*f)(fsource).operator std::string(); + fmt = "metallib"; + } + code << fsource; } - std::string code = cg.Finish(); - std::string fmt = "metal"; - std::string source = ""; - if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { - source = code; - code = (*f)(code).operator std::string(); - fmt = "metallib"; - } - return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); + return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 661df9305036..992df61980f8 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -192,17 +192,59 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } + void GenerateEntrypointForUnpackedAPI() { + code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix << "("; + int total_args = (metadata_->num_inputs + metadata_->num_outputs); + for (int i = 0; i < total_args; ++i) { + code_ << "arg" << i; + if (i + 1 != total_args) { + code_ << ","; + } + } + code_ << ");\n"; + code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " + "out_type_code, void* resource_handle) {\n"; + code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix << "("; + for (int i = 0; i < metadata_->num_inputs; ++i) { + code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; + } + for (int i = 0; i < metadata_->num_outputs; ++i) { + int j = metadata_->num_inputs + i; + code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; + if (i + 1 != metadata_->num_outputs) { + code_ << ","; + } + } + code_ << ");\n"; + code_ << "}\n"; + } + + void GenerateEntrypointForPackedAPI() { + code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix; + code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " + "out_type_code, void* resource_handle);\n"; + code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " + "out_type_code, void* resource_handle) {\n"; + code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix; + code_ << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n"; + code_ << "}\n"; + } + void GenerateAOTDescriptor() { code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n"; code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; code_ << "extern \"C\"\n"; code_ << "#endif\n"; - code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix; - code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " - "out_type_code, void* resource_handle);\n"; + if (target_->GetAttr("unpacked-api").value_or(Bool(false))) { + GenerateEntrypointForUnpackedAPI(); + } else { + GenerateEntrypointForPackedAPI(); + } code_ << "const tvm_model_t network = {\n" - << " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n" + << " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n" << " .num_input_tensors = " << metadata_->num_inputs << ",\n" << " .num_output_tensors = " << metadata_->num_outputs << ", \n" << "};\n"; diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 6226ba2f22b3..8ed08048cf2f 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ -#include #include #include diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index dc625b6a928d..d8f0f8e90238 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,7 +23,6 @@ */ #include "codegen_spirv.h" -#include #include #include #include diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index c2460b2d8ff6..128d60e7725a 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -449,24 +449,42 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // Declare appropriate capabilities for int/float types if (dtype.is_int() || dtype.is_uint()) { if (dtype.bits() == 8) { - ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability"; + ICHECK(spirv_support_.supports_int8) + << "Vulkan target does not support Int8 capability. " + << "If your device supports 8-bit int operations, " + << "please either add -supports_int8=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt8); } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_int16) << "Vulkan target does not support Int16 capability"; + ICHECK(spirv_support_.supports_int16) + << "Vulkan target does not support Int16 capability. " + << "If your device supports 16-bit int operations, " + << "please either add -supports_int16=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt16); } else if (dtype.bits() == 64) { - ICHECK(spirv_support_.supports_int64) << "Vulkan target does not support Int64 capability"; + ICHECK(spirv_support_.supports_int64) + << "Vulkan target does not support Int64 capability. " + << "If your device supports 64-bit int operations, " + << "please either add -supports_int64=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt64); } } else if (dtype.is_float()) { if (dtype.bits() == 16) { ICHECK(spirv_support_.supports_float16) - << "Vulkan target does not support Float16 capability"; + << "Vulkan target does not support Float16 capability. " + << "If your device supports 16-bit float operations, " + << "please either add -supports_float16=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityFloat16); } else if (dtype.bits() == 64) { ICHECK(spirv_support_.supports_float64) - << "Vulkan target does not support Float64 capability"; + << "Vulkan target does not support Float64 capability. " + << "If your device supports 64-bit float operations, " + << "please either add -supports_float64=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityFloat64); } } @@ -478,17 +496,25 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // supports Int8 but doesn't support 8-bit buffer access. if (dtype.bits() == 8) { ICHECK(spirv_support_.supports_storage_buffer_8bit_access) - << "Vulkan target does not support StorageBuffer8BitAccess"; + << "Vulkan target does not support StorageBuffer8BitAccess. " + << "If your device supports 8-bit buffer access, " + << "please either add -supports_8bit_buffer=1 to the target, " + << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess); extensions_used_.insert("SPV_KHR_8bit_storage"); ICHECK(spirv_support_.supports_storage_buffer_storage_class) << "Illegal Vulkan target description. " << "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class " - << "if VK_KHR_8bit_storage is supported"; + << "if VK_KHR_8bit_storage is supported. " + << "Please either add -supports_storage_buffer_storage_class=1 to the target, " + << "or query all device parameters by adding -from_device=0."; } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_storage_buffer_8bit_access) - << "Vulkan target does not support StorageBuffer16BitAccess"; + ICHECK(spirv_support_.supports_storage_buffer_16bit_access) + << "Vulkan target does not support StorageBuffer16BitAccess. " + << "If your device supports 16-bit buffer access, " + << "please either add -supports_16bit_buffer=1 to the target, " + << "or query all device parameters by adding -from_device=0."; extensions_used_.insert("SPV_KHR_16bit_storage"); if (spirv_support_.supports_storage_buffer_storage_class) { diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index ff9aee406574..e06bde08895d 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -35,17 +35,45 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { ICHECK_EQ(target->kind->device_type, kDLVulkan) << "SPIRVSupport can only be checked for vulkan device type"; - // Currently, this codifies the assumptions that were present and - // implicit in previous implementations. In the future, this will - // pull information from the specified `Target`. - - supports_storage_buffer_storage_class = (SPV_VERSION >= 0x10300); - supports_storage_buffer_8bit_access = true; - supports_storage_buffer_16bit_access = true; - supports_float16 = true; - supports_int8 = true; - supports_int16 = true; - supports_int64 = true; + if (target->GetAttr("supported_subgroup_operations")) { + supported_subgroup_operations = + target->GetAttr("supported_subgroup_operations").value(); + } + if (target->GetAttr("max_push_constants_size")) { + max_push_constants_size = target->GetAttr("max_push_constants_size").value(); + } + if (target->GetAttr("max_uniform_buffer_range")) { + max_uniform_buffer_range = target->GetAttr("max_uniform_buffer_range").value(); + } + if (target->GetAttr("max_storage_buffer_range")) { + max_storage_buffer_range = target->GetAttr("max_storage_buffer_range").value(); + } + if (target->GetAttr("max_per_stage_descriptor_storage_buffer")) { + max_per_stage_descriptor_storage_buffers = + target->GetAttr("max_per_stage_descriptor_storage_buffer").value(); + } + if (target->GetAttr("supports_storage_buffer_storage_class")) { + supports_storage_buffer_storage_class = + target->GetAttr("supports_storage_buffer_storage_class").value(); + } + if (target->GetAttr("supports_8bit_buffer")) { + supports_storage_buffer_8bit_access = target->GetAttr("supports_8bit_buffer").value(); + } + if (target->GetAttr("supports_16bit_buffer")) { + supports_storage_buffer_16bit_access = target->GetAttr("supports_16bit_buffer").value(); + } + if (target->GetAttr("supports_float16")) { + supports_float16 = target->GetAttr("supports_float16").value(); + } + if (target->GetAttr("supports_int8")) { + supports_int8 = target->GetAttr("supports_int8").value(); + } + if (target->GetAttr("supports_int16")) { + supports_int16 = target->GetAttr("supports_int16").value(); + } + if (target->GetAttr("supports_int64")) { + supports_int64 = target->GetAttr("supports_int64").value(); + } } } // namespace codegen diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 0dd96e07ed96..402e3291975f 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -23,7 +23,6 @@ #include "codegen_stackvm.h" #include -#include #include #include #include diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e06b2c05d3bf..b9d9706773f7 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -209,6 +209,85 @@ Map UpdateROCmAttrs(Map attrs) { return attrs; } +/*! + * \brief Update the attributes in the Vulkan target. + * \param attrs The original attributes + * \return The updated attributes + */ +Map UpdateVulkanAttrs(Map attrs) { + if (attrs.count("from_device")) { + int device_id = Downcast(attrs.at("from_device")); + Device device{kDLVulkan, device_id}; + const PackedFunc* get_target_property = + runtime::Registry::Get("device_api.vulkan.get_target_property"); + ICHECK(get_target_property) + << "Requested to read Vulkan parameters from device, but no Vulkan runtime available"; + + // Current vulkan implementation is partially a proof-of-concept, + // with long-term goal to move the -from_device functionality to + // TargetInternal::FromConfig, and to be usable by all targets. + // The duplicate list of parameters is needed until then, since + // TargetKind::Get("vulkan")->key2vtype_ is private. + std::vector bool_opts = { + "supports_float16", "supports_float32", + "supports_float64", "supports_int8", + "supports_int16", "supports_int32", + "supports_int64", "supports_8bit_buffer", + "supports_16bit_buffer", "supports_storage_buffer_storage_class", + "supports_push_descriptor", "supports_dedicated_allocation"}; + std::vector int_opts = {"supported_subgroup_operations", + "max_num_threads", + "thread_warp_size", + "max_block_size_x", + "max_block_size_y", + "max_block_size_z", + "max_push_constants_size", + "max_uniform_buffer_range", + "max_storage_buffer_range", + "max_per_stage_descriptor_storage_buffer", + "max_shared_memory_per_block", + "driver_version", + "vulkan_api_version", + "max_spirv_version"}; + std::vector str_opts = {"device_name"}; + + for (auto& key : bool_opts) { + if (!attrs.count(key)) { + attrs.Set(key, Bool(static_cast((*get_target_property)(device, key)))); + } + } + for (auto& key : int_opts) { + if (!attrs.count(key)) { + attrs.Set(key, Integer(static_cast((*get_target_property)(device, key)))); + } + } + for (auto& key : str_opts) { + if (!attrs.count(key)) { + attrs.Set(key, (*get_target_property)(device, key)); + } + } + + attrs.erase("from_device"); + } + + // Set defaults here, rather than in the .add_attr_option() calls. + // The priority should be user-specified > device-query > default, + // but defaults defined in .add_attr_option() are already applied by + // this point. Longer-term, would be good to add a + // `DeviceAPI::GetTargetProperty` function and extend "from_device" + // to work for all runtimes. + std::unordered_map defaults = {{"supports_float32", Bool(true)}, + {"supports_int32", Bool(true)}, + {"max_num_threads", Integer(256)}, + {"thread_warp_size", Integer(1)}}; + for (const auto& kv : defaults) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + return attrs; +} + /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) @@ -219,6 +298,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("system-lib") .add_attr_option("runtime") .add_attr_option("link-params", Bool(false)) + .add_attr_option("unpacked-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) @@ -229,6 +309,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("march") .add_attr_option("executor") .add_attr_option("workspace-byte-alignment") + .add_attr_option("unpacked-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) @@ -269,13 +350,45 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size", Integer(16)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("system-lib") - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .set_default_keys({"vulkan", "gpu"}); + .add_attr_option("from_device") + // Feature support + .add_attr_option("supports_float16") + .add_attr_option("supports_float32") + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32") + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supported_subgroup_operations") + // Physical device limits + .add_attr_option("max_num_threads") + .add_attr_option("thread_warp_size") + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") + // Other device properties + .add_attr_option("device_name") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") + // Tags + .set_default_keys({"vulkan", "gpu"}) + .set_attrs_preprocessor(UpdateVulkanAttrs); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("system-lib") diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 386bc539b924..190892b2283f 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -35,11 +35,12 @@ class ProducerToBufferTransformer : public StmtExprMutator { : tensor2buffers_(tensor2buffers) {} PrimExpr VisitExpr_(const ProducerLoadNode* op) final { - te::Tensor tensor = Downcast(op->producer); + auto visited_op = Downcast(StmtExprMutator::VisitExpr_(op)); + te::Tensor tensor = Downcast(visited_op->producer); auto it = tensor2buffers_.find(tensor); ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor; const Buffer& buffer = it->second; - return BufferLoad(buffer, op->indices); + return BufferLoad(buffer, visited_op->indices); } private: @@ -101,6 +102,12 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: f_push_block_vars(compute_op->axis); f_push_block_vars(compute_op->reduce_axis); + // If we have a rank 0 tensor then we manifest it as a rank 1 buffer with a single element. + if (compute_op->axis.size() == 0) { + iter_vars.push_back(IterVar(Range::FromMinExtent(0, 1), Var(), IterVarType::kDataPar)); + bindings.push_back(Var()); + } + // Step 2. Declare buffer and update op2buffers Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint()); info->tensor2buffers[tensor] = buffer; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc deleted file mode 100644 index 951bd6c18706..000000000000 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ /dev/null @@ -1,1124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file schedule_postproc_rewrite_for_tensor_core.cc - * - * \brief Rewrite the Stmt generated by ScheduleOps - * to accomondate tensorcore. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace te { - -using namespace te; -using runtime::StorageRank; -using runtime::StorageScope; -using runtime::ThreadScope; - -struct Tile { - int m{-1}; - int n{-1}; - int k{-1}; -}; - -std::string simplify_name(std::string input) { - auto pos = input.find("."); - if (pos != std::string::npos) { - return input.substr(0, pos); - } else { - return input; - } -} - -PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { - auto cast = input.as(); - if (cast == nullptr) { - return input; - } else if (cast->dtype == target_type) { - return cast->value; - } - return PrimExpr(); -} - -// MMAMatcher matches C = Cast(A)*Cast(B)+C, -// where A & B are fp16/int8 local buffers, -// and C is fp32/int32 local buffer. -class MMAMatcher : public StmtVisitor { - public: - explicit MMAMatcher(Map extern_buffer) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::pragma_tensor_core) { - tensor_core_on_ = true; - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(Downcast(op->producer)); - if (it == buf_map_.end()) { - return; - } - const BufferInfo& bi = it->second; - if (bi.released) { - return; - } - if (tensor_core_on_ && mma_sync_match_(op, bi)) { - matched_ = true; - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - if (!buf_map_.at(key).external) { - return; - } - this->VisitStmt(op->body); - } else { - BufferInfo bi; - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - inline bool Matched() const { return matched_; } - - friend class ScheduleAnalyser; - friend class BufferAnalyser; - - private: - struct BufferInfo { - std::string name; - DataType dtype; - bool external{false}; - bool released{false}; - bool same_as(const BufferInfo& bi) { - if (this->dtype != bi.dtype) return false; - if (this->name != bi.name) return false; - if (this->external != bi.external) return false; - if (this->released != bi.released) return false; - return true; - } - }; - - // Check whether the storage scope is local - bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { - auto tensor = Downcast(op->producer); - auto it = storage_scope_.find(tensor.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(tensor); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; - } - - // Do the pattern matching - bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { - auto* add = op->value.as(); - if (add == nullptr) { - return false; - } - - auto* load_c = add->a.as(); - BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || - !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { - return false; - } - - auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); - if (mul == nullptr) { - return false; - } - - auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); - BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) || - !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); - BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) || - !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - frag_reg_.insert(buffer_c.name); - frag_reg_.insert(buffer_a.name); - frag_reg_.insert(buffer_b.name); - buf_name_.insert(std::make_pair(load_a, buffer_a.name)); - buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); - - return true; - } - - std::unordered_map buf_map_; - std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; - std::unordered_set frag_reg_; - bool matched_{false}; - bool tensor_core_on_{false}; -}; - -// BodyVisitor visits the body stmt of original ComputeOp -// to get the access indices of input matrices, -// if it is recognized as matrix multiply. -class BodyVisitor : public StmtExprVisitor { - public: - BodyVisitor() {} - - void VisitExpr_(const ReduceNode* op) final { - auto* comm_add = op->combiner->result[0].as(); - if (comm_add == nullptr || op->combiner->result.size() > 1) { - return; - } - for (PrimExpr source : op->source) { - auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); - auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); - if (mul_0 == nullptr && mul_1 == nullptr) { - continue; - } - - tensorcore_candidate_ = true; - StmtExprVisitor::VisitExpr(source); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); - } - - friend class ScheduleAnalyser; - - private: - std::unordered_map> args_; - bool tensorcore_candidate_{false}; -}; - -// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major -class ScheduleAnalyser { - public: - explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} - - bool MatrixIdentify(Schedule schedule) { - // TODO(minmin): handle the case where MatMul is not the output stage - for (Operation output : schedule->outputs) { - const ComputeOpNode* compute = output.as(); - if (compute == nullptr) { - // Not a ComputeOp - continue; - } - auto axis = compute->axis; - auto reduce_axis = compute->reduce_axis; - if (axis.size() < 2 || reduce_axis.size() != 1) { - continue; - } - const VarNode* axis_var[2]; - const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size() - 2]->var.as(); - axis_var[1] = axis[axis.size() - 1]->var.as(); - reduce_axis_var = reduce_axis[0]->var.as(); - - BodyVisitor body_visitor; - for (PrimExpr expr : compute->body) { - body_visitor(expr); - } - if (!body_visitor.tensorcore_candidate_) { - continue; - } - for (auto iter : body_visitor.args_) { - auto name = iter.first; - auto args = iter.second; - if (args.size() < 2) { - continue; - } - const VarNode* var0 = args[args.size() - 2].as(); - const VarNode* var1 = args[args.size() - 1].as(); - if (var0 == nullptr || var1 == nullptr) { - continue; - } - std::string matrix_abc, major; - if (var0 == reduce_axis_var && var1 == axis_var[1]) { - matrix_abc = "matrix_a"; - major = "col_major"; - } else if (var0 == reduce_axis_var && var1 == axis_var[0]) { - matrix_abc = "matrix_b"; - major = "row_major"; - } else if (var0 == axis_var[1] && var1 == reduce_axis_var) { - matrix_abc = "matrix_a"; - major = "row_major"; - } else if (var0 == axis_var[0] && var1 == reduce_axis_var) { - matrix_abc = "matrix_b"; - major = "col_major"; - } - matrix_abc_.insert(std::make_pair(name, matrix_abc)); - matrix_major_.insert(std::make_pair(name, major)); - } - matrix_abc_.insert(std::make_pair(compute->name, "accumulator")); - matrix_major_.insert(std::make_pair(compute->name, "col_major")); - } - - for (auto& mma_sync : mma_sync_) { - auto& operands = mma_sync.second; - auto* load_a = operands[0].as(); - auto* load_b = operands[1].as(); - auto input0 = simplify_name(buf_name_.find(load_a)->second); - auto input1 = simplify_name(buf_name_.find(load_b)->second); - auto it0 = matrix_abc_.find(input0); - auto it1 = matrix_abc_.find(input1); - - if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) { - return false; - } - if (it0->second == "matrix_a" && it1->second == "matrix_b") { - return true; - } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { - mma_sync.second = Array{operands[1], operands[0], operands[2]}; - } else { - return false; - } - } - return true; - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; -}; - -// IndexVisitor visits access index of fragment -// to record variable for loop scaling -class IndexVisitor : public StmtExprVisitor { - public: - IndexVisitor() {} - - void VisitExpr_(const VarNode* op) final { - loop_scaling_.insert(std::make_pair(op, scaling_factor_)); - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map loop_scaling_; - unsigned scaling_factor_{0}; -}; - -// BufferAnalyser gets buffer info, -// e.g. thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public StmtExprVisitor { - public: - explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - frag_reg_(mma_matcher.frag_reg_) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.strides = kv.second->strides; - bi.shape = kv.second->shape; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - if (const IntImmNode* value = op->value.as()) { - thread_extent_.insert( - std::make_pair(op->node.as()->var->name_hint, value->value)); - } - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align) { - te::Tensor tensor = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[tensor]; - size_t dim = tuple->args[0].as()->value; - if (dim >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - this->VisitStmt(op->body); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - auto key = Downcast(op->producer); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(key->GetNameHint())) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(key->GetNameHint(), strides)); - - if (frag_reg_.count(bi.name)) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_load_.insert(std::make_pair(op, dst)); - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - std::vector tile_size; - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - tile_size.push_back(shape->value); - index_visitor.scaling_factor_ = shape->value; - } else { - invalid_ = true; - return; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - - std::string input_name = simplify_name(bi.name); - auto it = matrix_abc_.find(input_name); - auto it2 = matrix_major_.find(input_name); - bool ret = true; - if (it != matrix_abc_.end() && it2 != matrix_major_.end()) { - if (it->second == "matrix_a" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.m, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.n, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "accumulator") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (!ret) { - invalid_ = true; - return; - } - } - } - - const ProducerLoadNode* value = op->value.as(); - // TODO(tvm-team): string matching is dangerous, consider other means. - if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_store_.insert(std::make_pair(op, dst)); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - - auto tensor = Downcast(op->producer); - auto it = buf_map_.find(tensor); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(tensor->op->name)) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); - - if (!frag_reg_.count(bi.name)) { - return; - } - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external); - this->VisitStmt(op->body); - } else { - // create a buffer entry - BufferInfo bi; - - bi.bounds = op->bounds; - Array shape; - for (auto r : bi.bounds) { - shape.push_back(r->extent); - } - - Array strides; - if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; - const std::vector& avec = dim_align_[key]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = analyzer_.Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - strides = Array(rstrides.rbegin(), rstrides.rend()); - } - - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - bi.strides = strides; - bi.shape = shape; - - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - // Derive warp tile from thread tile, - // and check whether it is qualified for TensorCore. - bool QualifiedForTensorCore() { - if (invalid_) { - return false; - } - auto itx = thread_extent_.find("threadIdx.x"); - if (itx == thread_extent_.end()) { - return false; - } - int warp_threads_x = itx->second; - warp_tile_.m = warp_threads_x * thread_tile_.m; - warp_threads_y_ = 32 / warp_threads_x; - auto ity = thread_extent_.find("threadIdx.y"); - if (ity == thread_extent_.end()) { - return false; - } - if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) { - return false; - } - warp_tile_.n = warp_threads_y_ * thread_tile_.n; - warp_tile_.k = thread_tile_.k; - return supported_warp_tile_(); - } - - friend class TensorCoreIRMutator; - - private: - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - - struct BufferInfo { - std::string name; - DataType dtype; - Array strides; - Array shape; - Region bounds; - bool external{false}; - bool released{false}; - inline Array RelIndex(Array args) const { - if (bounds.size() != 0) { - Array index; - ICHECK_EQ(bounds.size(), args.size()); - for (size_t i = 0; i < bounds.size(); ++i) { - index.push_back(args[i] - bounds[i]->min); - } - return index; - } else { - return args; - } - } - }; - - bool assign_or_check_(int* dst, int src) { - if (*dst <= 0) { - *dst = src; - return true; - } - if (*dst == src) { - return true; - } - return false; - } - - bool supported_warp_tile_() { - if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { - return true; - } - - return false; - } - - std::unordered_map buf_map_; - std::unordered_map> dim_align_; - std::unordered_map storage_scope_; - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_set frag_reg_; - std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map thread_extent_; - IndexVisitor index_visitor; - Tile warp_tile_; - Tile thread_tile_; - arith::Analyzer analyzer_; - int warp_threads_y_{-1}; - bool invalid_{false}; -}; - -// ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public StmtExprMutator { - public: - explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr) { - if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImm(DataType::Int(32), 0); - return zero; - } - if (op->name_hint == "threadIdx.y") { - PrimExpr div = Div(expr, warp_y_); - PrimExpr mul = Mul(div, warp_y_); - return mul; - } - } - return expr; - } - - private: - PrimExpr warp_y_; -}; - -// TensorCoreIRMutator mutates the AST for TensorCore CodeGen -// based on tensor core intrinsics -class TensorCoreIRMutator : public StmtExprMutator { - public: - explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, - const BufferAnalyser& buffer_analyser) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} - - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - bounds_[key] = op->bounds; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - if (!frag_reg_.count(key->GetNameHint())) { - return stmt; - } - - auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); - - Region new_bounds; - for (size_t i = 0; i < op->bounds.size() - 2; ++i) { - new_bounds.push_back(op->bounds[i]); - } - ICHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return ProducerRealize(op->producer, new_bounds, op->condition, op->body); - } - return stmt; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (op->attr_key == tir::attr::realize_scope) { - auto node = op->node.as(); - if (node != nullptr) { - if (!frag_reg_.count(node->name)) { - return stmt; - } - - auto it = matrix_abc_.find(simplify_name(node->name)); - ICHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImm("wmma." + it->second); - Stmt body = this->VisitStmt(op->body); - return AttrStmt(op->node, op->attr_key, matrix_abc, body); - } - } - return stmt; - } - - Stmt VisitStmt_(const ProducerStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto it = mma_sync_.find(op); - if (it != mma_sync_.end()) { - const auto& operands = it->second; - PrimExpr a = operands[0]; - auto ca = a.as(); - PrimExpr b = operands[1]; - auto cb = b.as(); - PrimExpr c = operands[2]; - auto cc = c.as(); - - ObjectPtr buffer_node_a = make_object(); - ObjectPtr buffer_node_b = make_object(); - ObjectPtr buffer_node_c = make_object(); - - auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_bmma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } else { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_mma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } - }; - - auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); - }; - - auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); - }; - - return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); - } - - auto it2 = frag_load_.find(op); - if (it2 != frag_load_.end()) { - PrimExpr dst = it2->second; - if (op->value.as() != nullptr || op->value.as() != nullptr) { - auto pload = dst.as(); - - auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); - } - - const ProducerLoadNode* value = op->value.as(); - ICHECK(value != nullptr) << "Can only load fragment from a buffer"; - - auto it = strides_.find(value->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - PrimExpr mutated_value = thread_idx_mutator(op->value); - // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}); - - auto pload = dst.as(); - PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); - ICHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << pload->producer->GetNameHint(); - if (iter2->second == "col_major") { - matrix_major = StringImm("col_major"); - } else if (iter2->second == "row_major") { - matrix_major = StringImm("row_major"); - } else { - LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); - } - - auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); - } - - auto it3 = frag_store_.find(op); - if (it3 != frag_store_.end()) { - auto it = strides_.find(op->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - PrimExpr dst = it3->second; - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}); - - auto pload = op->value.as(); - - auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); - } - - return stmt; - } - - Stmt VisitStmt_(const ForNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - auto it = loop_scaling_.find(op->loop_var.get()); - if (it != loop_scaling_.end()) { - int scale_factor = it->second; - int scaled_extent_value = 1; - if (const IntImmNode* ori_extent = op->extent.as()) { - int ori_extent_value = ori_extent->value; - scaled_extent_value = ori_extent_value / scale_factor; - } - PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, - op->annotations); - } - } - return stmt; - } - - private: - Array get_tile_size_(const std::string& name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - ICHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; - } - - Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, - const ObjectPtr& buffer_node, - const std::function& call_back) { - auto tensor = Downcast(pload->producer); - auto it = bounds_.find(tensor); - ICHECK(it != bounds_.end()); - Array min_bound; - for (auto i : it->second) { - min_bound.push_back(i->min); - } - - ICHECK_GE(it->second.size(), 2); - Array shape; - for (size_t i = 0; i < it->second.size() - 2; ++i) { - shape.push_back(it->second[i]->extent); - } - auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); - shape.push_back(tile_size[0]); - shape.push_back(tile_size[1]); - - Array strides; - for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul(stride, shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - - PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - ICHECK_EQ(pload->indices.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); - } - - auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); - ICHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; - buffer_node->data = Var(tensor->op->name, DataType::Handle()); - buffer_node->name = tensor->op->name; - buffer_node->scope = "wmma." + it2->second; - buffer_node->dtype = tensor->dtype; - buffer_node->strides = strides; - buffer_node->shape = shape; - buffer_node->data_alignment = 1; - buffer_node->elem_offset = analyzer_.Simplify(elem_offset); - buffer_node->offset_factor = 1; - Buffer buffer(buffer_node); - - Array args; - for (size_t i = 0; i < pload->indices.size(); ++i) { - args.push_back(pload->indices[i]); - args.push_back(shape[i]); - } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args); - Array node = {buffer, tensor}; - return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); - } - - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map> strides_; - std::unordered_set frag_reg_; - std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; - arith::Analyzer analyzer_; - Tile warp_tile_; - int warp_threads_y_{-1}; -}; - -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer) { - // Check if current lower target is CUDA - auto target = tvm::Target::Current(true); - if (target.defined() && target->kind->name != "cuda") { - return stmt; - } - - // Check if current runtime support GPU CUDA - Device dev{kDLCUDA, 0}; - auto api = tvm::runtime::DeviceAPI::Get(dev, true); - if (api == nullptr) { - return stmt; - } - - MMAMatcher mma_matcher(extern_buffer); - mma_matcher(stmt); - if (!mma_matcher.Matched()) { - return stmt; - } - - ScheduleAnalyser schedule_analyser(mma_matcher); - if (!schedule_analyser.MatrixIdentify(schedule)) { - return stmt; - } - - BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser(stmt); - if (!buffer_analyser.QualifiedForTensorCore()) { - return stmt; - } - - return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); -} - -TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") - .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); - }); - -} // namespace te -} // namespace tvm diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 32cc51039be0..5c59961fe011 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -36,7 +36,6 @@ * - Add annotation of extern buffers using the buffer_map field * in the PrimFunc type. */ -#include #include #include #include diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 95c40f9a3c8e..4c59a1767372 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -87,9 +87,7 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - const PassInfo& pass_info = Info(); ICHECK(mod.defined()); - pass_ctx.Trace(mod, pass_info, true); std::vector deleted_list; IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); @@ -112,7 +110,6 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& gv : deleted_list) { func_dict->erase(gv); } - pass_ctx.Trace(mod, pass_info, false); return mod; } diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 8d52a621b900..dd7fee37e2d1 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -26,13 +26,13 @@ namespace tir { /******** Verification ********/ /*! - * \brief Verify the sref tree state is consistent with the IR + * \brief Verifies the sref tree state is consistent with the IR * \param self The schedule state containing the sref to be verified * \throw An exception will be thrown if the sref tree is not valid */ void VerifySRefTree(const ScheduleState& self); /*! - * \brief Verify the cached flags in the schedule state, including: + * \brief Verifies the cached flags in the schedule state, including: * - affine_binding * - region_cover * - stage_pipeline @@ -41,10 +41,53 @@ void VerifySRefTree(const ScheduleState& self); */ void VerifyCachedFlags(const ScheduleState& self); -/******** Binding ********/ +/******** Scope ********/ +/*! + * \brief Gets the sref to the scope root block, exclusive + * \param sref The block or loop sref to be retrieved + * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + */ +Optional GetScopeRoot(const StmtSRef& sref); + +/*! + * \brief Checks if scope the specified sref is in is a stage-pipeline and return it + * \param prim The name of the schedule primitive + * \param self The schedule state + * \param sref The sref whose scope is to be checked + * \throw ScheduleError if the sref has been the root of the AST (so it has no scope root), or its + * scope root is not a stage pipeline + * \return The block sref to the scope root + */ +StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref); + +/*! + * \brief Checks whether the block is a complete block under the scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root The sref to the root block of the scope that `block_sref` is in + * \return A boolean indicating if the block is a complete block + * \note Definition of a complete block: + * 1) All block vars are data parallel + * 2) Dominant: the block is the only writer of its output, + * dominating the reader of its output buffers + * 3) No overlap between the buffers the block reads and writes + */ +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root); + +/*! + * \brief Checks if the block is a complete block + * \param self The schedule state + * \param block_sref The sref to the block whose completeness is to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is not a complete block + */ +void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); +/******** Binding ********/ /*! - * \brief Verify if the block binding in a specific BlockRealize is an affine binding. + * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. * The binding can be represented as an injective affine map from the loop iterators. * \param realize The BlockRealize to be analyzed * \param loop_var_ranges The ranges of the loop variables @@ -55,7 +98,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va arith::Analyzer* analyzer); /*! - * \brief Extract the ranges of loop variables in a path of the sref tree + * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path * \param high_exclusive The highest node in the path, defaults to the scope root if not specified * \param extra_relax_scope If the scope is not global, the method will look beyond the limit and @@ -78,7 +121,7 @@ Map GetBindings(const BlockRealize& realize); /******** Block-loop relation ********/ /*! - * \brief Retrieve blocks in a specific function with its name + * \brief Retrieves blocks in a specific function with its name * \param self The schedule state * \param name The name of the blocks to be retrieved * \param func_name The name of the function @@ -86,14 +129,14 @@ Map GetBindings(const BlockRealize& realize); */ Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); /*! - * \brief Get the parent loops of the block in its scope, from outer to inner + * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state * \param block_sref The query block * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); /*! - * \brief Get the leaf blocks of a scope where a specific block/loop is in + * \brief Gets the leaf blocks of a scope where a specific block/loop is in * \param self The schedule state * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of leaf blocks diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e4b767bc40ad..d58dece3c644 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,6 +21,146 @@ namespace tvm { namespace tir { +/******** Scope ********/ + +Optional GetScopeRoot(const StmtSRef& sref) { + for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { + if (p->stmt->IsInstance()) { + return GetRef(p); + } + } + return NullOpt; +} + +StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref) { + class RootBlockError : public ScheduleError { + public: + explicit RootBlockError(IRModule mod) : mod_(mod) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The primitive does not operate on the root block"; + } + String DetailRenderTemplate() const final { + return "The primitive does not operate on the root block"; + } + Array LocationsOfInterest() const final { return {}; } + IRModule mod_; + }; + + class NotStagePipelineError : public ScheduleError { + public: + explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The scope root is not a stage pipeline"; + } + String DetailRenderTemplate() const final { + return R"(The scope {0} is not a stage pipeline. +Definition of a scope that is a stage pipeline: +- The region cover property holds for every of its child blocks +- No write-after-read dependency or opaque dependency, +- only read-after-write and write-after-write are allowed +- All the statements in the scope are schedulable statements, i.e. Block and For +)"; + } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + StmtSRef scope_root_sref{nullptr}; + if (Optional opt_scope_root_sref = GetScopeRoot(sref)) { + scope_root_sref = opt_scope_root_sref.value(); + } else { + throw RootBlockError(self->mod); + } + bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; + if (stage_pipeline == false) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + throw NotStagePipelineError(self->mod, GetRef(block)); + } + return scope_root_sref; +} + +/*! + * \brief Check the dominant property of a block: + * the block is the only writer of its output, dominating the reader of its output buffers + * \param self The schedule state + * \param block_sref The block whose dominant property is to be checked + * \return A boolean indicating if the block is a dominant block + */ +bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { + // Check whether the input block is the only writer of its outputs + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = + self->buffer_writers; + for (const BufferRegion& write_region : block->writes) { + ICHECK(buffer_writers.count(write_region->buffer)) + << "InternalError: buffer \"" << write_region->buffer->name + << "\" does not exist in the current scope, when querying block:\n" + << GetRef(block); + if (buffer_writers.at(write_region->buffer).size() != 1) { + return false; + } + } + return true; +} + +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root) { + BlockScope scope = self->GetBlockScope(scope_root); + // Cond 1. All block vars are data parallel + const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != kDataPar) { + return false; + } + } + // Cond 2. Dominant: the block is the only writer of its output, + // dominating the reader of its output buffers + if (!IsDominantBlock(scope, block_sref)) { + return false; + } + // Cond 3. No overlap between the buffers the block reads and writes + std::unordered_set written_buffers; + written_buffers.reserve(block->writes.size()); + for (const BufferRegion& write : block->writes) { + written_buffers.insert(write->buffer.get()); + } + for (const BufferRegion& read : block->reads) { + if (written_buffers.count(read->buffer.get())) { + return false; + } + } + return true; +} + +void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class IncompleteBlockError : public ScheduleError { + public: + explicit IncompleteBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + String FastErrorString() const final { return "ScheduleError: Incomplete block"; } + String DetailRenderTemplate() const final { + return R"(The block {0} is not a complete block. +Definition of a complete block: +1) All block vars are data parallel +2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +3) No overlap between the buffers the block reads and writes)"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + bool result = IsCompleteBlock(self, block_sref, scope_root_sref); + if (result == false) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + throw IncompleteBlockError(self->mod, GetRef(block)); + } +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index ef12f10fa924..0563d39427b1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -21,9 +21,11 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mode) { +Schedule Schedule::Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); + n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -136,6 +138,7 @@ class ScheduleCopier { scope->src2deps = Copy(old_info.scope->src2deps); scope->dst2deps = Copy(old_info.scope->dst2deps); scope->buffer_writers = Copy(old_info.scope->buffer_writers); + scope->stage_pipeline = old_info.scope->stage_pipeline; new_info.scope = BlockScope(std::move(scope)); result[Copy(old_sref)] = std::move(new_info); } @@ -173,21 +176,80 @@ class ScheduleCopier { void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { ScheduleCopier::Copy(this, new_state, new_symbol_table); + new_state->get()->DebugVerify(); } Schedule ConcreteScheduleNode::Copy() const { ObjectPtr n = make_object(); - Copy(&n->state_, &n->symbol_table_); + n->error_render_level_ = this->error_render_level_; + this->Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); } +/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */ +#define TVM_TIR_SCHEDULE_BEGIN() try { +/*! + * \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error + * message rendering + * \param level An ScheduleErrorRenderLevel enum, level of error rendering + * \sa ScheduleErrorRenderLevel + */ +#define TVM_TIR_SCHEDULE_END(primitive, level) \ + } \ + catch (const ScheduleError& error) { \ + if ((level) == ScheduleErrorRenderLevel::kDetail) { \ + throw tvm::runtime::Error(error.RenderReport(primitive)); \ + } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ + throw tvm::runtime::Error(error.FastErrorString()); \ + } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ + throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ + } \ + } + /******** Block/Loop relation ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { + class NotSingleResult : public ScheduleError { + public: + explicit NotSingleResult(String name, IRModule mod, const Array& blocks) + : name_(name), mod_(mod), blocks_{} { + blocks_.reserve(blocks.size()); + for (const StmtSRef& block_sref : blocks) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + blocks_.push_back(GetRef(block)); + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } + + String DetailRenderTemplate() const final { + if (blocks_.empty()) { + return "Cannot find a block with the name: " + name_; + } else { + return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_; + } + } + + String FastErrorString() const final { + if (blocks_.empty()) { + return "ScheduleError: Cannot find a block with the specified name"; + } else { + return "ScheduleError: Found multiple blocks with the specified name"; + } + } + + String name_; + IRModule mod_; + Array blocks_; + }; Array blocks = tir::GetBlocks(this->state_, name, func_name); - CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size() - << " blocks with the name: " << name; + if (blocks.size() != 1) { + TVM_TIR_SCHEDULE_BEGIN(); + throw NotSingleResult(name, this->state_->mod, blocks); + TVM_TIR_SCHEDULE_END("get-block", this->error_render_level_); + } return CreateRV(blocks[0]); } @@ -195,6 +257,28 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } +/******** Schedule: loops manipulation ********/ +/******** Schedule: compute location ********/ + +void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("compute-inline", this->error_render_level_); + this->state_->DebugVerify(); +} + +void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("reverse-compute-inline", this->error_render_level_); + this->state_->DebugVerify(); +} + +/******** Schedule: loop binding/annotation ********/ +/******** Schedule: cache read/write ********/ +/******** Schedule: reduction ********/ +/******** Schedule: blockize & tensorize ********/ + /******** FFI ********/ TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 39eab1159db9..8945fb9ee0dc 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -37,6 +37,8 @@ class ConcreteScheduleNode : public ScheduleNode { protected: /*! \brief The internal state of scheduling */ ScheduleState state_; + /*! \brief The level of error rendering */ + ScheduleErrorRenderLevel error_render_level_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ @@ -44,6 +46,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: void VisitAttrs(tvm::AttrVisitor* v) { + // `error_render_level_` is not visited // `state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied @@ -74,6 +77,14 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Block/Loop relation ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; + /******** Schedule: loops manipulation ********/ + /******** Schedule: compute location ********/ + void ComputeInline(const BlockRV& block) override; + void ReverseComputeInline(const BlockRV& block) override; + /******** Schedule: loop binding/annotation ********/ + /******** Schedule: cache read/write ********/ + /******** Schedule: reduction ********/ + /******** Schedule: blockize & tensorize ********/ /******** Utility functions ********/ protected: diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc new file mode 100644 index 000000000000..d8dcf57b91e4 --- /dev/null +++ b/src/tir/schedule/error.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +String ScheduleError::RenderReport(const String& primitive) const { + IRModule mod = this->mod(); + std::ostringstream os; + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR is:\n" + << AsTVMScript(mod); + Array locs = LocationsOfInterest(); + int n_locs = locs.size(); + std::vector roi_names; + roi_names.reserve(n_locs); + if (n_locs > 0) { + os << "Regions of interest:\n"; + for (const ObjectRef& obj : locs) { + String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); + os << name << "\n" << obj; + roi_names.emplace_back(std::move(name)); + } + os << "\n"; + } + std::string msg = DetailRenderTemplate(); + for (int i = 0; i < n_locs; ++i) { + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), roi_names[i]); + } + } + os << "Error message: " << msg; + return os.str(); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h new file mode 100644 index 000000000000..46447cfbde49 --- /dev/null +++ b/src/tir/schedule/error.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_ERROR_H_ +#define TVM_TIR_SCHEDULE_ERROR_H_ + +#include + +namespace tvm { +namespace tir { + +/*! \brief Error that happens during TensorIR scheduling */ +class ScheduleError : public tvm::runtime::Error { + public: + /*! \brief Base constructor */ + ScheduleError() : tvm::runtime::Error("") {} + /*! \brief The error occurred in this IRModule */ + virtual IRModule mod() const = 0; + /*! \brief The locations of interest that we want to point out */ + virtual Array LocationsOfInterest() const = 0; + /*! + * \brief Returns an error string template for rendering, corresponds to the "detail" mode. + * \sa ScheduleErrorRenderLevel + * \note The template is a string, e.g. + * "Some error occurred on block {0} and loop {1} blah blah" + * And renderer will replace {0} and {1} according to the list provided LocationsOfInterest. Right + * now it only printed out all the locations in plain text, but in the future, we may want to mark + * the IR with underscores and attach names to each location of interest, like what synr does. + */ + virtual String DetailRenderTemplate() const = 0; + /*! + * \brief Returns an error string without needing to render, corresponds to the "fast" mode + * \sa ScheduleErrorRenderLevel + */ + virtual String FastErrorString() const = 0; + /*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ + String RenderReport(const String& primitive) const; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_ERROR_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h new file mode 100644 index 000000000000..ab8299e38169 --- /dev/null +++ b/src/tir/schedule/primitive.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#define TVM_TIR_SCHEDULE_PRIMITIVE_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Schedule: loops manipulation ********/ + +/******** Schedule: compute location ********/ +/*! + * \brief Inline a block into its consumer(s). It requires: + * 1) The block is a complete non-root block, which only produces one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The body of the block must be a BufferStore statement in the form of, + * A[i, j, k, ...] = ... + * where the indices of the LHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param self The state of the schedule + * \param block_sref The sref to the block to be inlined to its consumer(s) + */ +TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); +/*! + * \brief Inline a block into its only producer. It requires: + * 1) The block is a complete non-root block, which only produces and consumers one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The only producer of the block is a read-after-write producer and a complete non-root block + * 4) The body of the block must be a BufferStore statement in the form of, + * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param self The state of the schedule + * \param block_sref The sref to the block to be inlined to its producer + */ +TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); + +/******** Schedule: loop binding/annotation ********/ + +/******** Schedule: cache read/write ********/ + +/******** Schedule: reduction ********/ + +/******** Schedule: blockize & tensorize ********/ + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_PRIMITIVE_H_ diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc new file mode 100644 index 000000000000..6bd6388fafff --- /dev/null +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -0,0 +1,677 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of + 'A[i, j, k, ...] = f(i, j, k, ...)', +where the indices on the left are distinct atomic variables, +and there should not no variables other than the index variables)"; + +static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of + `B[...] = g(i, j, k, A[i, j, k, ...] ...)`, +where A is the only buffer the block consumes, whose indices are distinct atomic variables, +and there should not no variables other than the index variables)"; + +class NotSingleReadWriteBuffer : public ScheduleError { + public: + explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) + : mod_(mod), is_read_(is_read), block_(std::move(block)) {} + + String FastErrorString() const final { + return is_read_ ? "ScheduleError: The block is allowed to read only a single buffer region" + : "ScheduleError: The block is allowed to write only a single buffer region"; + } + + String DetailRenderTemplate() const final { + if (is_read_) { + int k = block_->reads.size(); + return "The block is only allowed to read a single buffer region, but it reads " + + std::to_string(k) + " region(s): {0}"; + } else { + int k = block_->writes.size(); + return "The block is only allowed to write a single buffer region, but it writes " + + std::to_string(k) + " region(s): {0}"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + bool is_read_; + Block block_; + + static Buffer GetSingleRead(const ScheduleState& self, const Block& block) { + if (block->reads.size() != 1) { + throw NotSingleReadWriteBuffer(self->mod, true, block); + } + return block->reads[0]->buffer; + } + + static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { + if (block->writes.size() != 1) { + throw NotSingleReadWriteBuffer(self->mod, false, block); + } + return block->writes[0]->buffer; + } +}; + +class BodyAnalysisError : public ScheduleError { + public: + explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) + : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block cannot be inlined because its body pattern does not meet the " + "condition for inlining"; + } + + String DetailRenderTemplate() const final { + return is_reverse_ ? kErrBodyReverseInline : kErrBodyInline; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + bool is_reverse_; + IRModule mod_; + Block block_; +}; + +class OnlyLeafError : public ScheduleError { + public: + explicit OnlyLeafError(IRModule mod, Block leaf_block, StmtSRef scope_root_sref) + : mod_(mod), leaf_block_(std::move(leaf_block)), scope_root_(nullptr) { + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + this->scope_root_ = GetRef(scope_root); + } + + String FastErrorString() const final { + return "ScheduleError: Cannot remove the only leaf in the scope"; + } + + String DetailRenderTemplate() const final { + return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " + "scope will be empty."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } + + IRModule mod_; + Block leaf_block_; + Block scope_root_; +}; + +class NonSingleProducerError : public ScheduleError { + public: + explicit NonSingleProducerError(IRModule mod, Block block) + : mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The consumer block to be inlined is required to have only a single " + "producer block, and the producer block should be a complete block who has only a " + "single consumer"; + } + + String DetailRenderTemplate() const final { + return "The consumer block {0} to be inlined is required to have only a single " + "producer block, and the producer block should be a complete block who has only a " + "single consumer"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + + static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + Array producers = scope->GetDepsByDst(consumer_block_sref); + if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) { + const StmtSRef& producer_block_sref = producers[0]->src; + if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) { + Array consumers = scope->GetDepsBySrc(producer_block_sref); + if (consumers.size() == 1) { + return; + } + } + } + const BlockNode* block = TVM_SREF_TO_BLOCK(block, consumer_block_sref); + throw NonSingleProducerError(self->mod, GetRef(block)); + } +}; + +class OpaqueAccessError : public ScheduleError { + public: + explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) + : mod_(mod), scope_root_(nullptr) { + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + this->scope_root_ = GetRef(scope_root); + } + + String FastErrorString() const final { + return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its " + "subregion is matched into other blocks"; + } + + String DetailRenderTemplate() const final { + return "The buffer to be inlined has opaque access (e.g. `B.data`), or its " + "subregion is matched into other blocks: {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {scope_root_}; } + + IRModule mod_; + Block scope_root_; +}; + +/*! + * \brief Construct a new AST, with a specific sref tree leaf removed. + * The leaf's ancestors who have only a single child will be removed too. + * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed + * \param src_stmt The root of the subtree where the replacement begins + * \param tgt_stmt The root of the subtree after the replacement + * \return A boolean indicating if the leaf can be removed successfully + * \note Removal is not conducted beyond scope-level. + * + * An example of the removal plan, say we are removing the leaf block "B" from the AST. + * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "B") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + 1.0 + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + * + * Ths method does not mutate the AST, instead it returns the a `(src_stmt, tgt_stmt)` pair as a + * plan to substitute certain pieces of the IR. + * + * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is: + * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + */ +bool LeafBlockRemovalPlan(const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt) { + // Go upwards until find an ancestor with more than one child + const StmtNode* last_stmt = leaf_block_sref->stmt; + StmtSRefNode* sref = leaf_block_sref->parent; + for (;; last_stmt = sref->stmt, sref = sref->parent) { + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + if (seq->size() > 1) { + break; + } + } + } else { + // Removal is not done beyond scope-level. + // When encountering a block, i.e. the scope root, we simply stop + break; + } + } + if (const auto* block = sref->StmtAs()) { + if (const auto* seq = block->body.as()) { + ObjectPtr n = make_object(*block); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(block); + *tgt_stmt = Stmt(std::move(n)); + return true; + } + } + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + ObjectPtr n = make_object(*loop); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(loop); + *tgt_stmt = Stmt(std::move(n)); + return true; + } + } + return false; +} + +/*! + * \brief The base class of the inliner, which handles: + * 1) Substitute a subtree with the specific block being inlined + * 2) Update the block signature to reflect the changes of read/write/allocated buffers + * 3) Maintain a list of index variables and their substition of the buffer being inlined + */ +class BaseInliner : public StmtExprMutator { + protected: + explicit BaseInliner(const Buffer& inlined_buffer, const Block& inlined_block, + const StmtSRef& scope_root_sref) + : inlined_buffer_(inlined_buffer), + inlined_store_(inlined_block->body.as()), + scope_root_sref_(scope_root_sref) { + AddBuffersInBlockSignature(inlined_block.get()); + } + + PrimExpr VisitExpr_(const VarNode* var) final { + CheckOpaqueAccess(var); + return StmtExprMutator::VisitExpr_(var); + } + + PrimExpr VisitExpr_(const LoadNode* load) final { + CheckOpaqueAccess(load->buffer_var.get()); + return StmtExprMutator::VisitExpr_(load); + } + + Stmt VisitStmt_(const StoreNode* store) final { + CheckOpaqueAccess(store->buffer_var.get()); + return StmtExprMutator::VisitStmt_(store); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (src_stmt.get() == loop) { + loop = tgt_stmt.as(); + ICHECK(loop != nullptr); + } + return StmtExprMutator::VisitStmt_(loop); + } + + Stmt VisitStmt_(const BlockNode* block) final { + CheckMatchBufferRegion(block); + AddBuffersInBlockSignature(block); + Block src_block = GetRef(block); + if (src_block.same_as(src_stmt)) { + block = tgt_stmt.as(); + ICHECK(block != nullptr); + } + Block tgt_block = Downcast(StmtExprMutator::VisitStmt_(block)); + bool is_scope_root = src_block.get() == scope_root_sref_->stmt; + tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root); + block_reuse.Set(src_block, tgt_block); + return std::move(tgt_block); + } + + /*! + * \brief Check if the indices are atomic distinct variables and the access is n-dimensional. + * If so, set `self->idx_vars_` properly. + * \param indices The indices to be extracted + * \param expected_ndim The expected ndim of the access + * \return A boolean flag indicating if the check is successful + */ + bool UpdateAndCheckIndexVars(const Array& indices, int expected_ndim) { + int n = indices.size(); + if (n != expected_ndim) { + // Failure: dimension mismatch + return false; + } + std::vector result; + result.reserve(n); + for (const PrimExpr& i : indices) { + if (const auto* var = i.as()) { + result.push_back(var); + } else { + // Failure: indexing expression is not a variable + return false; + } + } + using DistinctSet = std::unordered_set; + int n_distinct = DistinctSet(result.begin(), result.end()).size(); + if (n != n_distinct) { + // Failure: indexing variables are not distinct + return false; + } + if (idx_vars_.empty()) { + idx_vars_ = std::move(result); + } else if (!support::ArrayWithSameContent(idx_vars_, result)) { + // Failure: indexing variables are not consitent in different BufferLoads + return false; + } + return true; + } + + /*! + * \brief Set the mapping of index substitution `self->idx_sub_` + * \param indices The expressions that the corresponding index variables are replaced to + */ + void SetIndexSubstitution(const Array& indices) { + ICHECK_EQ(indices.size(), idx_vars_.size()); + int n = idx_vars_.size(); + idx_sub_.reserve(n); + for (int i = 0; i < n; ++i) { + idx_sub_[idx_vars_[i]] = indices[i]; + } + } + + private: + /*! + * \brief Add the buffers in the block signature to the `buffer_var_map_`, + * which is used for auto-completion of a block's read/write region + * \param block The block whose signature to be added + */ + void AddBuffersInBlockSignature(const BlockNode* block) { + for (const BufferRegion& buffer_region : block->reads) { + const Buffer& buffer = buffer_region->buffer; + buffer_var_map_.Set(buffer->data, buffer); + } + for (const BufferRegion& buffer_region : block->writes) { + const Buffer& buffer = buffer_region->buffer; + buffer_var_map_.Set(buffer->data, buffer); + } + for (const Buffer& buffer : block->alloc_buffers) { + buffer_var_map_.Set(buffer->data, buffer); + } + } + + /*! + * \brief Update the following block signature: + * 1) tir.alloc_buffer, if the block is scope root + * 2) tir.reads, if the block is not scope root + * 3) tir.writes, if the block is not scope root + * \param block The block to be updated + * \param is_scope_root A flag indicating if a block is the scope root of the block to be inlined + * \return The updated block + */ + Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { + // Step 1. Update `BlockNode::alloc_buffers` + Array alloc_buffers; + if (is_scope_root) { + alloc_buffers.reserve(block->alloc_buffers.size()); + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (!alloc_buffer.same_as(inlined_buffer_)) { + alloc_buffers.push_back(alloc_buffer); + } + } + } else { + alloc_buffers = std::move(block->alloc_buffers); + } + // Step 2. Update `BlockNode::reads` and `BlockNode::writes` + Array reads = std::move(block->reads); + Array writes = std::move(block->writes); + if (!is_scope_root) { + Array> inspected = GetBlockAccessRegion(block, buffer_var_map_); + reads = std::move(inspected[0]); + writes = std::move(inspected[1]); + } + // Step 3. Assemble the result + BlockNode* n = block.CopyOnWrite(); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->alloc_buffers = std::move(alloc_buffers); + return block; + } + + /*! + * \brief Opaque access to the buffer to be inlined is disallowed. + * This method checks if a buffer var belongs to the buffer + * \param buffer_var The buffer var to be checked + */ + void CheckOpaqueAccess(const VarNode* buffer_var) { + if (inlined_buffer_->data.get() == buffer_var) { + this->has_opaque_access = true; + } + } + + /*! + * \brief The buffer to be inlined is not allowed to be region matched. + * This method checks if a block has the disallowed behavior of buffer region match. + * \param block The block to be checked + */ + void CheckMatchBufferRegion(const BlockNode* block) { + for (const MatchBufferRegion& match_buffer_region : block->match_buffers) { + const Buffer& matched = match_buffer_region->source->buffer; + if (matched.same_as(inlined_buffer_)) { + this->has_opaque_access = true; + } + } + } + + protected: + /*! \brief The buffer to be inlined */ + Buffer inlined_buffer_{nullptr}; + /*! \brief The body of the block to be inlined */ + const BufferStoreNode* inlined_store_{nullptr}; + /*! \brief The scope root */ + StmtSRef scope_root_sref_{nullptr}; + /*! \brief Maps a buffer's data field to itself */ + Map buffer_var_map_; + /*! \brief The indices used for indexing the buffer to be inlined */ + std::vector idx_vars_; + /*! \brief The mapping to substitute index variables to PrimExprs */ + std::unordered_map idx_sub_; + + public: + /*! + * \brief The Stmt to be replaced when removing the leaf block + * \note The pair (src_stmt, tgt_stmt) are produced by LeafBlockRemovalPlan to indicate a + * transformation on top of the input AST. We take this approach to avoid changing the AST twice + */ + Stmt src_stmt{nullptr}; + /*! \brief The Stmt to be replaced to when removing the leaf block */ + Stmt tgt_stmt{nullptr}; + /*! \brief The reuse mapping of block srefs */ + Map block_reuse; + /*! \brief Indicates if there is any opaque access of the inlined buffer */ + bool has_opaque_access{false}; +}; + +/*! + * \brief Helper to inline the producer block into its consumer(s) + * The derived class implements the following functionalities: + * 1) Substitute `BufferLoad` on the buffer to be inlined + * to its value calculation in the producer block + * 2) Analyze the producer block to determine the remapping of index variables + */ +class ComputeInliner : public BaseInliner { + public: + explicit ComputeInliner(const Buffer& inlined_buffer, const Block& producer_block, + const StmtSRef& scope_root_sref) + : BaseInliner(inlined_buffer, producer_block, scope_root_sref) {} + + bool BodyPatternAllowInline(const Block& producer_block) { + if (inlined_store_ == nullptr) { + return false; + } + int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { + return false; + } + return true; + } + + private: + using BaseInliner::VisitExpr_; + using BaseInliner::VisitStmt_; + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (!load->buffer.same_as(inlined_buffer_)) { + return std::move(load); + } + return ReplaceInlinedBuffer(std::move(load)); + } + + PrimExpr ReplaceInlinedBuffer(BufferLoad load) { + SetIndexSubstitution(load->indices); + return Substitute(inlined_store_->value, idx_sub_); + } +}; + +/*! + * \brief Helper to inline the consumer block into its producer + * The derived class implements the following functionalities: + * 1) Analyze the consumer block to determine the remapping of index variables + * 2) Substitute `BufferStore` of the buffer to be inlined, + * replacing it with direct writing to the buffer that consumer writes + */ +class ReverseComputeInliner : public BaseInliner { + class Substituter : public StmtExprMutator { + public: + explicit Substituter(ReverseComputeInliner* self) : self_(self) {} + + private: + PrimExpr VisitExpr_(const VarNode* var) final { + auto it = self_->idx_sub_.find(var); + ICHECK(it != self_->idx_sub_.end()); + return (*it).second; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + return load->buffer.same_as(self_->inlined_buffer_) ? self_->producer_rhs_ : load; + } + + ReverseComputeInliner* self_; + }; + + public: + explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block, + const StmtSRef& scope_root_sref) + : BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {} + + bool BodyPatternAllowInline(const Block& consumer_block) { + if (inlined_store_ == nullptr) { + // Failure: block body is not BufferStore + return false; + } + std::vector loads = ExtractBufferLoad(inlined_buffer_, inlined_store_); + if (loads.size() == 0) { + // Failure: no BufferLoad from the `inlined_buffer_` + return false; + } + int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + for (const BufferLoadNode* load : loads) { + if (!UpdateAndCheckIndexVars(load->indices, n_vars)) { + // Failure: incorrect of inconsistent index vars + return false; + } + } + return true; + } + + private: + using BaseInliner::VisitExpr_; + using BaseInliner::VisitStmt_; + + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (!store->buffer.same_as(inlined_buffer_)) { + return std::move(store); + } + return ReplaceInlinedBuffer(std::move(store)); + } + + Stmt ReplaceInlinedBuffer(BufferStore producer) { + SetIndexSubstitution(producer->indices); + producer_rhs_ = producer->value; + return Substituter(this)(GetRef(inlined_store_)); + } + + /*! + * \brief Extracts expressions that loads a specific buffer + * \param buffer The buffer to be loaded from + * \param from The BufferStore statement to be extracted from + * \return A list of `BufferLoad` expressions + */ + static std::vector ExtractBufferLoad(const Buffer& buffer, + const BufferStoreNode* from) { + struct Extractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.get() == buffer) { + result.push_back(load); + } + ExprVisitor::VisitExpr_(load); + } + const BufferNode* buffer; + std::vector result; + } extractor; + extractor.buffer = buffer.get(); + for (const PrimExpr& expr : from->indices) { + extractor(expr); + } + extractor(from->value); + return std::move(extractor.result); + } + + /*! \brief The RHS value of the producer's BufferStore statement */ + PrimExpr producer_rhs_{nullptr}; +}; + +void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { + const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); + Block producer_block = GetRef(_producer_block); + Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); + // Step 1. Get the scope block + StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, producer_block_sref); + // Step 2. Check completeness + CheckCompleteBlock(self, producer_block_sref, scope_root_sref); + // Step 3. Analyze the block body + ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref); + if (!inliner.BodyPatternAllowInline(producer_block)) { + throw BodyAnalysisError(false, self->mod, producer_block); + } + // Step 4. Create a plan that removes the leaf block to be inlined + if (!LeafBlockRemovalPlan(producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { + throw OnlyLeafError(self->mod, producer_block, scope_root_sref); + } + // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, + // and update other blocks who read from the removed block + Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + if (inliner.has_opaque_access) { + throw OpaqueAccessError(self->mod, scope_root_sref); + } + // Step 6. Do the real mutation on the AST and the sref tree in the schedule state + self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); +} + +void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { + const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); + Block consumer_block = GetRef(_consumer_block); + Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); + // Step 1. Get the scope block + StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, consumer_block_sref); + // Step 2. Check completeness + CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); + // Step 3. Check if the consumer has a single complete producer + NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); + // Step 4. Analyze the block body + ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref); + if (!inliner.BodyPatternAllowInline(consumer_block)) { + throw BodyAnalysisError(true, self->mod, consumer_block); + } + // Step 5. Create a plan that removes the leaf block to be inlined + if (!LeafBlockRemovalPlan(consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { + throw OnlyLeafError(self->mod, consumer_block, scope_root_sref); + } + // Step 6. Create an AST where the leaf `consumer_block_sref` points to is removed, + // and update other blocks who read from the removed block + Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + if (inliner.has_opaque_access) { + throw OpaqueAccessError(self->mod, scope_root_sref); + } + // Step 7. Do the real mutation on the AST and the sref tree in the schedule state + self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index b407b07e5312..115f7936f64e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // /**************** (FFI) Constructor ****************/ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule { + .set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule { IRModule mod{nullptr}; if (const auto* func = obj.as()) { mod = IRModule({{GlobalVar("main"), GetRef(func)}}); @@ -66,7 +66,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); } - return Schedule::Concrete(mod, debug_mode); + return Schedule::Concrete(mod, debug_mode, + static_cast(error_render_level)); }); /******** (FFI) Lookup random variables ********/ @@ -121,6 +122,16 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); +/******** (FFI) loops manipulation ********/ +/******** (FFI) compute location ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") + .set_body_method(&ScheduleNode::ComputeInline); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") + .set_body_method(&ScheduleNode::ReverseComputeInline); +/******** (FFI) loop binding/annotation ********/ +/******** (FFI) cache read/write ********/ +/******** (FFI) reduction ********/ +/******** (FFI) blockize & tensorize ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index b72fd8e05706..19ed995ac8cc 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -34,7 +34,10 @@ #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" +#include "../../support/array.h" #include "./analysis.h" +#include "./error.h" +#include "./primitive.h" namespace tvm { namespace tir { @@ -113,6 +116,33 @@ inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_sco return static_cast(storage_scope.rank) <= static_cast(thread_scope.rank); } +/******** SeqStmt ********/ + +/*! + * \brief Remove a specific Stmt from a SeqStmt. If a SeqStmt contains a BlockRealize, + * whose block is the Stmt to be removed, then remove that BlockRealize too. + * \param seq The SeqStmt to be removed from + * \param to_remove The Stmt to be removed + * \return The removal result + */ +inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { + ICHECK_GT(seq->size(), 1); + Array new_stmts; + new_stmts.reserve(seq->size()); + for (const Stmt& stmt : seq->seq) { + if (to_remove.same_as(stmt)) { + continue; + } + if (const auto* realize = stmt.as()) { + if (to_remove.same_as(realize->block)) { + continue; + } + } + new_stmts.push_back(stmt); + } + return SeqStmt::Flatten(new_stmts); +} + /******** Integer set ********/ /*! @@ -131,22 +161,6 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/*! - * \brief Converts an N-dimensional integer set to N-dimensional region - * \param nd_int_set The integer set - * \return The region as the result of conversion - */ -inline Array AsRegion(const Array& nd_int_set, arith::Analyzer* analyzer) { - Array result; - result.reserve(nd_int_set.size()); - for (const arith::IntSet& int_set : nd_int_set) { - PrimExpr min = analyzer->Simplify(int_set.min()); - PrimExpr extent = analyzer->Simplify(int_set.max() - int_set.min() + 1); - result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); - } - return result; -} - } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 0cc0086897d8..ee52a6fc0988 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,7 +20,6 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ -#include #include #include #include diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc new file mode 100644 index 000000000000..6e8793fbd367 --- /dev/null +++ b/src/tir/transforms/make_unpacked_api.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "arg_binder.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +PrimFunc MakeUnpackedAPI(PrimFunc&& func) { + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute"; + + auto target = func->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute"; + + auto* func_ptr = func.CopyOnWrite(); + + // Setup device context + int target_device_type = target.value()->kind->device_type; + Integer device_type(target_device_type); + Integer device_id(0); + PrimExpr node = StringImm("default"); + const Stmt nop = Evaluate(0); + std::vector device_init; + + // Create arg to buffer binder + std::unordered_map vmap; + ArgBinder binder(&vmap); + + // Collect variables and buffers to map between + Array args; + std::vector> var_def; + std::vector> buffer_def; + + for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { + Var param = func_ptr->params[i]; + Var v_arg = Var("arg" + std::to_string(i), param->dtype); + + auto it = func_ptr->buffer_map.find(param); + if (it != func_ptr->buffer_map.end()) { + buffer_def.emplace_back(v_arg, (*it).second); + } else { + var_def.emplace_back(v_arg, param); + } + + args.push_back(v_arg); + } + + // Bind variables then bind buffers to them to ensure correct ordering + for (const auto& kv : var_def) { + binder.Bind(kv.second, kv.first, kv.first->name_hint, true); + } + for (const auto& kv : buffer_def) { + binder.Bind(kv.second->data, kv.first, kv.first->name_hint, true); + } + + if (buffer_def.size()) { + device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop)); + device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); + } + + func_ptr->body = MergeNest({device_init, binder.init_nest(), binder.asserts()}, func_ptr->body); + func_ptr->params = args; + func_ptr->ret_type = PrimType(DataType::Int(32)); + + // return the function. + return std::move(func); +} + +namespace transform { + +Pass MakeUnpackedAPI() { + auto pass_func = [](IRModule m, PassContext ctx) { + IRModuleNode* mptr = m.CopyOnWrite(); + std::vector> updates; + + for (const auto& kv : mptr->functions) { + if (auto* n = kv.second.as()) { + PrimFunc func = GetRef(n); + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { + auto updated_func = MakeUnpackedAPI(std::move(func)); + updates.push_back({kv.first, updated_func}); + } + } + } + + for (const auto& pair : updates) { + mptr->AddUnchecked(pair.first, pair.second); + } + return m; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 921c7ad79509..f01d98707586 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,7 +22,6 @@ * \brief Split device function from host. */ #include -#include #include #include #include diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0bce3bbc7f53..db54d5a99a91 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,11 +174,31 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3], args[4]); + Tensor x = args[0]; + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { + Array begin_static = args[1]; + Array end_static = args[2]; + Array strides_static = args[3]; + Array axes = args[4]; + std::string slice_mode = args[5]; + if (axes.size() > 0) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + *rv = dynamic_strided_slice(x, begin, end, strides); + } }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = dynamic_strided_slice(args[0], args[1], args[2], args[3]); + te::Tensor begin = args[1]; + te::Tensor end = args[2]; + te::Tensor strides = args[3]; + *rv = dynamic_strided_slice(args[0], begin, end, strides); }); TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 5e4533733d2e..16dfd56a69ea 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 8cc5c4bc0a3a..204a824f9248 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = lower(s, args, "func", binds); + auto lowered = LowerSchedule(s, args, "func", binds); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -116,8 +116,8 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = lower(s1, args1, "elemwise_add", binds); - auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 63819308a666..7d1fa790146e 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -19,7 +19,10 @@ #include #include -#include +#include +#include +#include +#include #include #include diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index cf22577a791a..f993f9605c91 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -19,7 +19,6 @@ #include #include -#include #include #include #include diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc index d5c089b26194..8c74d4151818 100644 --- a/tests/cpp/relay_dismantler_test.cc +++ b/tests/cpp/relay_dismantler_test.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include #include #include @@ -38,6 +37,8 @@ #include #include +#include + using namespace tvm; using namespace tvm::relay; @@ -69,6 +70,80 @@ TEST(Relay, OutOfStack_cast) { ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); } +TEST(Relay, OutOfStack_packed_func) { + constexpr int len = 1e6; + auto foo = [] { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto add_func = tvm::runtime::Registry::Get("relay.op._make.add"); + auto y = (*add_func)(x, one); + for (int i = 0; i < len; ++i) { + y = (*add_func)(y, one); + } + + // check if still reachable + int k = 0; + Expr e = y; + while (e.defined() && e.as() != nullptr) { + e = e.as()->args[0]; + ++k; + } + ASSERT_EQ(len + 1, k); + }; + ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); +} + +TEST(Relay, CallNodeSharedArgs) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Call y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Call(relu_op, {y}, Attrs(), {}); + ASSERT_EQ(1, y.get()->args[0].as()->args.size()); + y = relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args); + ASSERT_EQ(1, y.get()->args[0].as()->args.size()); +} + +TEST(Relay, TupleSharedFields) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Expr y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Call(relu_op, {y}, Attrs(), {}); + { + Expr y1 = relay::Tuple(y.as()->args); + Expr y2 = relay::Tuple(y.as()->args); + + y1 = relay::Call(relu_op, {y1}); + y2 = relay::Call(relu_op, {y2}); + y = y1; + } + ASSERT_EQ(1, y.as()->args[0].as()->fields[0].as()->args.size()); +} + +TEST(Relay, TupleiGetItemSharedTuple) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Expr y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Tuple({y}); + { + Expr y1 = relay::TupleGetItem(y, 0); + Expr y2 = relay::TupleGetItem(y, 0); + + y1 = relay::Call(relu_op, {y1}); + y2 = relay::Call(relu_op, {y2}); + y = y1; + } + ASSERT_EQ(1, y.as() + ->args[0] + .as() + ->tuple.as() + ->fields[0] + .as() + ->args.size()); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc index 289574aef1e2..6d38e1017042 100644 --- a/tests/cpp/relay_transform_sequential_test.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -121,6 +121,14 @@ TEST(Relay, Sequential) { ICHECK(tvm::StructuralEqual()(f, expected)); } +TEST(PassContextListConfigs, Basic) { + Map> configs = relay::transform::PassContext::ListConfigs(); + ICHECK_EQ(configs.empty(), false); + + auto config = configs["relay.backend.use_auto_scheduler"]; + ICHECK_EQ(config["type"], "IntImm"); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/cpp/texture_copy_test.cc b/tests/cpp/texture_copy_test.cc new file mode 100644 index 000000000000..688bcab758ca --- /dev/null +++ b/tests/cpp/texture_copy_test.cc @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include + +TEST(TextureCopy, HostDeviceRT) { + using namespace tvm; + bool enabled = tvm::runtime::RuntimeEnabled("opencl"); + if (!enabled) { + LOG(INFO) << "Skip texture copy test because opencl runtime is disabled.\n"; + return; + } + + std::vector shape{16, 16, 4}; + auto cpu_arr0 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr1 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + String mem_scope = "global.texture"; + auto opencl_txarr0 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLOpenCL, 0}, mem_scope); + + size_t size = 1; + for (size_t i = 0; i < shape.size(); ++i) { + size *= static_cast(shape[i]); + } + + std::random_device dev; + std::mt19937 mt(dev()); + std::uniform_real_distribution<> random(-10.0, 10.0); + + // Random initialize host ndarray + for (size_t i = 0; i < size; i++) { + static_cast(cpu_arr0->data)[i] = random(mt); + } + + // Do a roundtrip from host storage to opencl texture storage and back + cpu_arr0.CopyTo(opencl_txarr0); + opencl_txarr0.CopyTo(cpu_arr1); + for (size_t i = 0; i < size; ++i) { + ICHECK_LT( + std::fabs(static_cast(cpu_arr1->data)[i] - static_cast(cpu_arr0->data)[i]), + 1e-5); + } +} + +TEST(TextureCopy, OverwritePoolSubview) { + using namespace tvm; + bool enabled = tvm::runtime::RuntimeEnabled("opencl"); + if (!enabled) { + LOG(INFO) << "Skip texture copy test because opencl runtime is disabled.\n"; + return; + } + + std::vector shape{16, 16, 4}; + std::vector shape_pool{32, 32, 4}; + auto cpu_arr0 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr1 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_pool0 = runtime::NDArray::Empty(shape_pool, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_pool1 = runtime::NDArray::Empty(shape_pool, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + String mem_scope = "global.texture"; + auto opencl_txpool = + runtime::NDArray::Empty(shape_pool, {kDLFloat, 32, 1}, {kDLOpenCL, 0}, mem_scope); + auto opencl_txarr0 = opencl_txpool.CreateView(shape, {kDLFloat, 32, 1}); + + std::random_device dev; + std::mt19937 mt(dev()); + std::uniform_real_distribution<> random(-10.0, 10.0); + + size_t size = 1; + size_t size_pool = 1; + for (size_t i = 0; i < shape_pool.size(); ++i) { + size *= static_cast(shape[i]); + size_pool *= static_cast(shape_pool[i]); + } + + // Random initialize host pool storage + for (size_t i = 0; i < size_pool; i++) { + static_cast(cpu_pool0->data)[i] = random(mt); + } + + // Random initialize host array + for (int64_t h = 0; h < shape[0]; h++) { + for (int64_t w = 0; w < shape[1]; w++) { + for (int64_t rgba = 0; rgba < shape[2]; rgba++) { + static_cast(cpu_arr0->data)[shape[1] * shape[2] * h + shape[2] * w + rgba] = 1.1f; + } + } + } + + // Copy to texture pool for initialization + cpu_pool0.CopyTo(opencl_txpool); + // Copy host data to subview into texture storage + cpu_arr0.CopyTo(opencl_txarr0); + // Copy modified pool back + opencl_txpool.CopyTo(cpu_pool1); + + // Check that modifications to pool follow two dimensional + // strides according to the written texture shape. + for (int64_t h = 0; h < shape_pool[0]; h++) { + for (int64_t w = 0; w < shape_pool[1]; w++) { + for (int64_t rgba = 0; rgba < shape_pool[2]; rgba++) { + size_t i = shape_pool[1] * shape_pool[2] * h + shape_pool[2] * w + rgba; + if (h < shape[0] && w < shape[1] && rgba < shape[2]) { + size_t j = shape[1] * shape[2] * h + shape[2] * w + rgba; + ICHECK_LT(std::fabs(static_cast(cpu_pool1->data)[i] - + static_cast(cpu_arr0->data)[j]), + 1e-5); + } else { + ICHECK_LT(std::fabs(static_cast(cpu_pool1->data)[i] - + static_cast(cpu_pool0->data)[i]), + 1e-5); + } + } + } + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 967df8d2b7b4..73d06ccf7c4c 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -129,19 +129,30 @@ "tests/micro/zephyr/testdata/digit-2.jpg", "tests/micro/zephyr/testdata/digit-9.jpg", "tests/micro/zephyr/testdata/mnist-8.onnx", + "tests/micro/zephyr/testdata/ic_sample_fp32_8.npy", # microTVM Zephyr runtime - "apps/microtvm/zephyr/demo_runtime/prj.conf", - "apps/microtvm/zephyr/demo_runtime/boards/qemu_x86.conf", - "apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv32.conf", - "apps/microtvm/zephyr/demo_runtime/boards/qemu_riscv64.conf", - "apps/microtvm/zephyr/demo_runtime/boards/nrf5340dk_nrf5340_cpuapp.conf", - "apps/microtvm/zephyr/demo_runtime/boards/nucleo_f746zg.conf", - "apps/microtvm/zephyr/demo_runtime/boards/stm32f746g_disco.conf", - "apps/microtvm/zephyr/demo_runtime/boards/mps2_an521.conf", - "apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-i386", - "apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-arm", - "apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv32", - "apps/microtvm/zephyr/demo_runtime/qemu-hack/qemu-system-riscv64", + "apps/microtvm/zephyr/qemu-hack/qemu-system-i386", + "apps/microtvm/zephyr/qemu-hack/qemu-system-arm", + "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32", + "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64", + "apps/microtvm/zephyr/host_driven/prj.conf", + "apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf", + "apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf", + "apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf", + "apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf", + "apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf", + "apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf", + "apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf", + "apps/microtvm/zephyr/host_driven/qemu-hack", + "apps/microtvm/zephyr/aot_demo/prj.conf", + "apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf", + "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv32.conf", + "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv64.conf", + "apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf", + "apps/microtvm/zephyr/aot_demo/boards/nucleo_f746zg.conf", + "apps/microtvm/zephyr/aot_demo/boards/stm32f746g_disco.conf", + "apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf", + "apps/microtvm/zephyr/aot_demo/qemu-hack", # microTVM Virtual Machines "apps/microtvm/reference-vm/zephyr/Vagrantfile", "apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template", diff --git a/tests/lint/git-black.sh b/tests/lint/git-black.sh index 993a2b2a9ff8..4052c052463e 100755 --- a/tests/lint/git-black.sh +++ b/tests/lint/git-black.sh @@ -37,6 +37,10 @@ if [[ "$#" -lt 1 ]]; then exit 1 fi +# required to make black's dep click to work +export LC_ALL=C.UTF-8 +export LANG=C.UTF-8 + if [ ! -x "$(command -v black)" ]; then echo "Cannot find black" exit 1 @@ -61,5 +65,5 @@ if [[ ${INPLACE_FORMAT} -eq 1 ]]; then "${CMD[@]}" else echo "Running black in checking mode" - black --diff --check ${FILES[@]} + python3 -m black --diff --check ${FILES[@]} fi diff --git a/tests/micro/test_runtime_micro_on_arm.py b/tests/micro/test_runtime_micro_on_arm.py deleted file mode 100644 index 0212c3ea2692..000000000000 --- a/tests/micro/test_runtime_micro_on_arm.py +++ /dev/null @@ -1,370 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os - -import numpy as np -import tvm -from tvm import te -from tvm.contrib import graph_executor, utils -from tvm import relay -import tvm.micro as micro -from tvm.micro import create_micro_mod - -# Use real micro device - an STM32F746 discovery board -# SETUP: -# Be sure to have openocd installed and running -# Ex : openocd -f board/stm32f7discovery.cfg -# Be sure to have the ST CMSIS library downloaded, installed and -# Ex : export CMSIS_ST_PATH="/home/yourid/st/STM32Cube_FW_F7_V1.16.0/Drivers/CMSIS" -DEV_CONFIG_A = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) -DEV_CONFIG_B = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) -TARGET = "micro_dev" - - -def relay_micro_build(func, dev_config, params=None): - """Create a graph executor module with a micro device context from a Relay function. - - Parameters - ---------- - func : relay.Function - function to compile - - dev_config : Dict[str, Any] - MicroTVM config dict for the target device - - params : dict - input parameters that do not change during inference - - Return - ------ - mod : tvm.runtime.Module - graph executor module for the target device - """ - with tvm.transform.PassContext( - disabled_pass={"FuseOps"}, config={"tir.disable_vectorize": True} - ): - graph, c_mod, params = relay.build(func, target=TARGET, params=params) - micro_mod = micro.create_micro_mod(c_mod, dev_config) - ctx = tvm.micro_dev(0) - mod = graph_executor.create(graph, micro_mod, ctx) - mod.set_input(**params) - return mod - - -GDB_INIT_TEMPLATE = """ -layout asm -target remote localhost:{gdb_port} -set $pc = UTVMInit -break UTVMDone -""" - - -def reset_gdbinit(): - if "server_port" not in DEV_CONFIG_A: - return - try: - gdb_init_dir = os.environ["MICRO_GDB_INIT_DIR"] - except KeyError: - return - with open(f"{gdb_init_dir}/.gdbinit", "w") as f: - gdb_port = DEV_CONFIG_A["server_port"] - 3333 - f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port)) - - -def test_alloc(): - """Test tensor allocation on the device.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - with micro.Session(DEV_CONFIG_A): - ctx = tvm.micro_dev(0) - np_tensor = np.random.uniform(size=shape).astype(dtype) - micro_tensor = tvm.nd.array(np_tensor, ctx) - tvm.testing.assert_allclose(np_tensor, micro_tensor.numpy()) - - -def test_add(): - """Test a module which performs addition.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - reset_gdbinit() - - # Construct TVM expression. - tvm_shape = tvm.runtime.convert(shape) - A = te.placeholder(tvm_shape, name="A", dtype=dtype) - B = te.placeholder(tvm_shape, name="B", dtype=dtype) - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = te.create_schedule(C.op) - - func_name = "fadd" - c_mod = tvm.build(s, [A, B, C], target="c", name=func_name) - - with micro.Session(DEV_CONFIG_A) as sess: - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - micro_func = micro_mod[func_name] - ctx = tvm.micro_dev(0) - - a_np = np.random.uniform(size=shape).astype(dtype) - a = tvm.nd.array(a_np, ctx) - b_np = np.random.uniform(size=shape).astype(dtype) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) - micro_func(a, b, c) - - # ensure inputs weren't corrupted - tvm.testing.assert_allclose(a.numpy(), a_np) - tvm.testing.assert_allclose(b.numpy(), b_np) - # ensure output is correct - tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) - - -def test_workspace_add(): - """Test a module which uses a workspace to compute an intermediate value.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - reset_gdbinit() - - # Construct TVM expression. - tvm_shape = tvm.runtime.convert(shape) - A = te.placeholder(tvm_shape, name="A", dtype=dtype) - B = te.placeholder(tvm_shape, name="B", dtype=dtype) - B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") - C = te.compute(A.shape, lambda *i: B(*i) + 1, name="C") - s = te.create_schedule(C.op) - - func_name = "fadd_two_workspace" - c_mod = tvm.build(s, [A, C], target="c", name=func_name) - - with micro.Session(DEV_CONFIG_A) as sess: - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - micro_func = micro_mod[func_name] - ctx = tvm.micro_dev(0) - a_np = np.random.uniform(size=shape).astype(dtype) - a = tvm.nd.array(a_np, ctx) - c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) - micro_func(a, c) - - # ensure input wasn't corrupted - tvm.testing.assert_allclose(a.numpy(), a_np) - # ensure output is correct - tvm.testing.assert_allclose(c.numpy(), a.numpy() + 2.0) - - -def test_graph_executor(): - """Test a program which uses the graph executor.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - xx = relay.multiply(x, x) - z = relay.add(xx, relay.const(1.0)) - func = relay.Function([x], z) - - with micro.Session(DEV_CONFIG_A): - mod = relay_micro_build(func, DEV_CONFIG_A) - - x_in = np.random.uniform(size=shape[0]).astype(dtype) - mod.run(x=x_in) - result = mod.get_output(0).numpy() - - tvm.testing.assert_allclose(mod.get_input(0).numpy(), x_in) - tvm.testing.assert_allclose(result, x_in * x_in + 1.0) - - -def test_conv2d(): - if not tvm.runtime.enabled("micro_dev"): - return - - from tvm.relay import create_executor - from tvm.relay import transform - - dshape = (1, 4, 16, 16) - dtype = "int8" - func_name = "fused_nn_conv2d" - - reset_gdbinit() - - # Construct Relay program. - x = relay.var("x", shape=dshape, dtype=dtype) - conv_expr = relay.nn.conv2d(x, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=4) - func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr) - mod = tvm.IRModule.from_expr(func) - mod = transform.InferType()(mod) - - x_shape = list(map(lambda x: x.value, mod["main"].params[0].checked_type.shape)) - w_shape = list(map(lambda x: x.value, mod["main"].params[1].checked_type.shape)) - out_shape = list(map(lambda x: x.value, mod["main"].ret_type.shape)) - - with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - graph, c_mod, params = relay.build(mod, target="c") - - with micro.Session(DEV_CONFIG_A): - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - candidate_func_name = func_name - for i in range(100): - try: - micro_func = micro_mod[candidate_func_name] - break - except tvm.TVMError as e: - candidate_func_name = f"{func_name}_{i}" - else: - assert False - ctx = tvm.micro_dev(0) - - x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx) - w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx) - result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx) - micro_func(x_data, w_data, result) - - out_data = np.zeros(out_shape, dtype=dtype) - params = {"x": x_data.numpy(), "w": w_data.numpy()} - intrp = create_executor("debug") - expected_result = intrp.evaluate(mod["main"])(x_data, w_data) - - tvm.testing.assert_allclose(result.numpy(), expected_result.numpy()) - - -def test_interleave_sessions(): - """Test closing and reopening sessions.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - with sess_b: - np_tensor_b = np.random.uniform(size=shape).astype(dtype) - micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - with sess_a: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).numpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - with sess_b: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B) - add_const_mod.run(x=micro_tensor_b) - add_result = add_const_mod.get_output(0).numpy() - tvm.testing.assert_allclose(add_result, np_tensor_b + 1.0) - - -def test_nested_sessions(): - """Test entering and exiting nested session contexts.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - with sess_b: - np_tensor_b = np.random.uniform(size=shape).astype(dtype) - micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).numpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - - -def test_inactive_session_use(): - """Test the use of objects allocated in a session that is no longer active.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - - with sess_b: - # These objects belong to `sess_a`. - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).numpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - - -# TODO add workspace alloc/free stress test - -if __name__ == "__main__": - test_alloc() - print() - print("finished alloc test") - input("[press enter to continue]") - test_add() - print() - print("finished add test") - input("[press enter to continue]") - test_workspace_add() - print() - print("finished workspace add test") - input("[press enter to continue]") - test_graph_executor() - print() - print("finished graph executor test") - input("[press enter to continue]") - test_conv2d() - print() - print("finished conv2d test") - input("[press enter to continue]") - # disable for now as these are currently broken - # test_interleave_sessions() - # print() - # print('finished interleaved sessions test') - # input('[press enter to continue]') - # test_nested_sessions() - # print() - # print('finished nested sessions test') - # input('[press enter to continue]') - test_inactive_session_use() - print() - print("finished use inactive session test") - input("[press enter to continue]") diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 17aeec4afe18..6ca5a530be9d 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -44,6 +44,17 @@ def pytest_addoption(parser): parser.addoption( "--west-cmd", default="west", help="Path to `west` command for flashing device." ) + parser.addoption( + "--skip-build", + action="store_true", + help="If set true, reuses build from the previous test run. Otherwise, build from the scratch.", + ) + parser.addoption( + "--tvm-debug", + action="store_true", + default=False, + help="If set true, enable a debug session while the test is running. Before running the test, in a separate shell, you should run: ", + ) def pytest_generate_tests(metafunc): @@ -54,3 +65,13 @@ def pytest_generate_tests(metafunc): @pytest.fixture def west_cmd(request): return request.config.getoption("--west-cmd") + + +@pytest.fixture +def skip_build(request): + return request.config.getoption("--skip-build") + + +@pytest.fixture +def tvm_debug(request): + return request.config.getoption("--tvm-debug") diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index e217ec39ed1d..96bcdfe5d86d 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -42,34 +42,27 @@ import conftest -# If set, build the uTVM binary from scratch on each test. -# Otherwise, reuses the build from the previous test run. -BUILD = True - -# If set, enable a debug session while the test is running. -# Before running the test, in a separate shell, you should run: -# python -m tvm.exec.microtvm_debug_shell -DEBUG = False - _LOG = logging.getLogger(__name__) PLATFORMS = conftest.PLATFORMS -def _make_sess_from_op(model, zephyr_board, west_cmd, op_name, sched, arg_bufs): +def _make_sess_from_op(model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config): target = tvm.target.target.micro(model) target = tvm.target.Target(target=target, host=target) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, target=target, name=op_name) - return _make_session(model, target, zephyr_board, west_cmd, mod) + return _make_session(model, target, zephyr_board, west_cmd, mod, build_config) -def _make_session(model, target, zephyr_board, west_cmd, mod): - test_name = f"{os.path.splitext(os.path.abspath(__file__))[0]}_{zephyr_board}" - prev_build = f"{test_name}-last-build.micro-binary" - workspace_root = ( - f'{test_name}_workspace/{datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")}' +def _make_session(model, target, zephyr_board, west_cmd, mod, build_config): + parent_dir = os.path.dirname(__file__) + filename = os.path.splitext(os.path.basename(__file__))[0] + prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" + workspace_root = os.path.join( + f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), ) workspace_parent = os.path.dirname(workspace_root) if not os.path.exists(workspace_parent): @@ -78,7 +71,7 @@ def _make_session(model, target, zephyr_board, west_cmd, mod): test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) tvm_source_dir = os.path.join(test_dir, "..", "..", "..") - runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "demo_runtime") + runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "host_driven") compiler = zephyr.ZephyrCompiler( project_dir=runtime_path, board=zephyr_board, @@ -92,14 +85,14 @@ def _make_session(model, target, zephyr_board, west_cmd, mod): opts["lib_opts"]["ccflags"] = ["-std=gnu++14"] flasher_kw = {} - if DEBUG: + if build_config["debug"]: flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) session_kw = { "flasher": compiler.flasher(**flasher_kw), } - if BUILD: + if not build_config["skip_build"]: session_kw["binary"] = tvm.micro.build_static_runtime( # the x86 compiler *expects* you to give the exact same dictionary for both # lib_opts and bin_opts. so the library compiler is mutating lib_opts and @@ -122,19 +115,20 @@ def _make_session(model, target, zephyr_board, west_cmd, mod): return tvm.micro.Session(**session_kw) -def _make_add_sess(model, zephyr_board, west_cmd): +def _make_add_sess(model, zephyr_board, west_cmd, build_config): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.placeholder((1,), dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, B, C]) + return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, B, C], build_config) # The same test code can be executed on both the QEMU simulation and on real hardware. -def test_compile_runtime(platform, west_cmd): +def test_compile_runtime(platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. def test_basic_add(sess): @@ -149,14 +143,15 @@ def test_basic_add(sess): system_lib.get_function("add")(A_data, B_data, C_data) assert (C_data.numpy() == np.array([6, 7])).all() - with _make_add_sess(model, zephyr_board, west_cmd) as sess: + with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: test_basic_add(sess) -def test_platform_timer(platform, west_cmd): +def test_platform_timer(platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. def test_basic_add(sess): @@ -176,13 +171,14 @@ def test_basic_add(sess): assert result.mean > 0 assert len(result.results) == 3 - with _make_add_sess(model, zephyr_board, west_cmd) as sess: + with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: test_basic_add(sess) -def test_relay(platform, west_cmd): +def test_relay(platform, west_cmd, skip_build, tvm_debug): """Testing a simple relay graph""" model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} shape = (10,) dtype = "int8" @@ -196,7 +192,7 @@ def test_relay(platform, west_cmd): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): graph, mod, params = tvm.relay.build(func, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod) as session: + with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device ) @@ -208,9 +204,10 @@ def test_relay(platform, west_cmd): tvm.testing.assert_allclose(result, x_in * x_in + 1) -def test_onnx(platform, west_cmd): +def test_onnx(platform, west_cmd, skip_build, tvm_debug): """Testing a simple ONNX model.""" model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} # Load test images. this_dir = os.path.dirname(__file__) @@ -229,7 +226,7 @@ def test_onnx(platform, west_cmd): relay_mod = relay.transform.DynamicToStatic()(relay_mod) # We add the -link-params=1 option to ensure the model parameters are compiled in. - # There is currently a bug preventing the demo_runtime environment from receiving + # There is currently a bug preventing the host_driven environment from receiving # the model weights when set using graph_mod.set_input(). # See: https://github.com/apache/tvm/issues/7567 target = tvm.target.target.micro(model, options=["-link-params=1"]) @@ -237,7 +234,7 @@ def test_onnx(platform, west_cmd): lowered = relay.build(relay_mod, target, params=params) graph = lowered.get_graph_json() - with _make_session(model, target, zephyr_board, west_cmd, lowered.lib) as session: + with _make_session(model, target, zephyr_board, west_cmd, lowered.lib, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device ) @@ -309,14 +306,16 @@ def visit_call(self, call): return super().visit_call(call) -def check_result(relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result): +def check_result( + relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result, build_config +): """Helper function to verify results""" TOL = 1e-5 target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): graph, mod, params = tvm.relay.build(relay_mod, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod) as session: + with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: rt_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device ) @@ -335,9 +334,10 @@ def check_result(relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape tvm.testing.assert_allclose(out.numpy(), results[idx], rtol=TOL, atol=TOL) -def test_byoc_utvm(platform, west_cmd): +def test_byoc_utvm(platform, west_cmd, skip_build, tvm_debug): """This is a simple test case to check BYOC capabilities of uTVM""" model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -391,8 +391,42 @@ def test_byoc_utvm(platform, west_cmd): model=model, zephyr_board=zephyr_board, west_cmd=west_cmd, + build_config=build_config, ) +def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config): + A = tvm.te.placeholder(shape, dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + A[i], name="C") + sched = tvm.te.create_schedule(C.op) + return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, C], build_config) + + +@pytest.mark.parametrize( + "shape,", + [ + pytest.param((1 * 1024,), id="(1*1024)"), + pytest.param((4 * 1024,), id="(4*1024)"), + pytest.param((16 * 1024,), id="(16*1024)"), + ], +) +def test_rpc_large_array(platform, west_cmd, skip_build, tvm_debug, shape): + """Test large RPC array transfer.""" + model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_tensors(sess): + a_np = np.random.randint(low=-128, high=127, size=shape, dtype="int8") + + A_data = tvm.nd.array(a_np, device=sess.device) + assert (A_data.asnumpy() == a_np).all() + C_data = tvm.nd.array(np.zeros(shape, dtype="int8"), device=sess.device) + assert (C_data.asnumpy() == np.zeros(shape)).all() + + with _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config) as sess: + test_tensors(sess) + + if __name__ == "__main__": - sys.exit(pytest.main([os.path.dirname(__file__)] + sys.argv[1:])) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py new file mode 100644 index 000000000000..dc277c245078 --- /dev/null +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +from hashlib import new +import logging +import os +import sys +import logging +import pathlib + +import pytest +import numpy as np + +import tvm +import tvm.rpc +import tvm.micro +import tvm.relay as relay + +from tvm.micro.contrib import zephyr +from tvm.contrib import utils +from tvm.contrib.download import download_testdata + +import conftest + +_LOG = logging.getLogger(__name__) + +PLATFORMS = conftest.PLATFORMS + + +def _build_session_kw(model, target, zephyr_board, west_cmd, mod, runtime_path, build_config): + parent_dir = os.path.dirname(__file__) + filename = os.path.splitext(os.path.basename(__file__))[0] + prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" + workspace_root = os.path.join( + f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), + ) + workspace_parent = os.path.dirname(workspace_root) + if not os.path.exists(workspace_parent): + os.makedirs(workspace_parent) + workspace = tvm.micro.Workspace(debug=True, root=workspace_root) + + compiler = zephyr.ZephyrCompiler( + project_dir=runtime_path, + board=zephyr_board, + zephyr_toolchain_variant="zephyr", + west_cmd=west_cmd, + env_vars={"ZEPHYR_RUNTIME": "ZEPHYR-AOT"}, + ) + + opts = tvm.micro.default_options(os.path.join(runtime_path, "crt")) + opts["bin_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) + opts["lib_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) + + flasher_kw = {} + if build_config["debug"]: + flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) + + session_kw = { + "flasher": compiler.flasher(**flasher_kw), + } + + if not build_config["skip_build"]: + session_kw["binary"] = tvm.micro.build_static_runtime( + workspace, + compiler, + mod, + opts, + executor="aot", + extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], + ) + if os.path.exists(prev_build): + os.unlink(prev_build) + session_kw["binary"].archive(prev_build, metadata_only=True) + else: + unarchive_dir = utils.tempdir() + session_kw["binary"] = tvm.micro.MicroBinary.unarchive( + prev_build, unarchive_dir.relpath("binary") + ) + + return session_kw + + +def _create_header_file(tensor_name, npy_data, output_path): + """ + This method generates a header file containing the data contained in the numpy array provided. + It is used to capture the tensor data (for both inputs and expected outputs). + """ + file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve() + # create header file + raw_path = file_path.with_suffix(".h").resolve() + with open(raw_path, "w") as header_file: + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + + if npy_data.dtype == "int8": + header_file.write(f"int8_t {tensor_name}[] =") + elif npy_data.dtype == "int32": + header_file.write(f"int32_t {tensor_name}[] = ") + elif npy_data.dtype == "uint8": + header_file.write(f"uint8_t {tensor_name}[] = ") + elif npy_data.dtype == "float32": + header_file.write(f"float {tensor_name}[] = ") + + header_file.write("{") + for i in np.ndindex(npy_data.shape): + header_file.write(f"{npy_data[i]}, ") + header_file.write("};\n\n") + + +def _read_line(fd): + data = "" + new_line = False + while True: + if new_line: + break + new_data = fd.read(1, timeout_sec=10) + logging.debug(f"read data: {new_data}") + for item in new_data: + new_c = chr(item) + data = data + new_c + if new_c == "\n": + new_line = True + break + return data + + +def _get_message(fd, expr: str): + while True: + data = _read_line(fd) + logging.debug(f"new line: {data}") + if expr in data: + return data + + +def test_tflite(platform, west_cmd, skip_build, tvm_debug): + """Testing a TFLite model.""" + model, zephyr_board = PLATFORMS[platform] + input_shape = (1, 32, 32, 3) + output_shape = (1, 10) + build_config = {"skip_build": skip_build, "debug": tvm_debug} + + this_dir = os.path.dirname(__file__) + tvm_source_dir = os.path.join(this_dir, "..", "..", "..") + runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "aot_demo") + model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" + model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") + + # Import TFLite model + tflite_model_buf = open(model_path, "rb").read() + try: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + + # Load TFLite model and convert to Relay + relay_mod, params = relay.frontend.from_tflite( + tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "float32"} + ) + + target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lowered = relay.build(relay_mod, target, params=params) + + # Load sample and generate input/output header files + sample_url = "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/testdata_image_classification_fp32_8.npy" + sample_path = download_testdata( + sample_url, "testdata_image_classification_fp32_8.npy", module="data" + ) + sample = np.load(sample_path) + model_files_path = os.path.join(runtime_path, "include") + _create_header_file((f"input_data"), sample, model_files_path) + _create_header_file( + "output_data", np.zeros(shape=output_shape, dtype="float32"), model_files_path + ) + + session_kw = _build_session_kw( + model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config + ) + transport = session_kw["flasher"].flash(session_kw["binary"]) + transport.open() + transport.write(b"start\n", timeout_sec=5) + + result_line = _get_message(transport, "#result") + result_line = result_line.strip("\n") + result_line = result_line.split(":") + result = int(result_line[1]) + time = int(result_line[2]) + logging.info(f"Result: {result}\ttime: {time} ms") + assert result == 8 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 137484330db8..9deaa758639e 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -169,34 +169,37 @@ def test_pooling(): fp32_dtype = ("float32", -127, 128, 0.001, 0.001) uint8_dtype = ("uint8", 0, 255, 1, 0) - + # fmt: off trials = [ - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (27, 27, 512)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (27, 27, 512), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16), (1, 0),], # 20.05: "exclude_padding equal false is not supported for AVG Pooling with padding on quantized types" # ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)], - ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16), (0, 1),], ] - + # fmt: on for ( typef, (dtype, low, high, atol, rtol), size, stride, + dilation, pad, ceil_mode, count_include_pad, input_shape, + (tvm_ops, acl_partitions), ) in trials: shape = (1, *input_shape) outputs = [] @@ -205,7 +208,16 @@ def test_pooling(): } func = _get_pooling_model( - shape, dtype, typef, size, stride, pad, ceil_mode, count_include_pad, iter(inputs) + shape, + dtype, + typef, + size, + stride, + dilation, + pad, + ceil_mode, + count_include_pad, + iter(inputs), ) config = { @@ -215,15 +227,25 @@ def test_pooling(): "pooling type": typef, "dtype": dtype, "padding": pad, + "dilation": dilation, "ceil_mode": ceil_mode, "count_include_pad": count_include_pad, "inputs": inputs, } verify_saturation = True if dtype == "uint8" else False - for acl in [False, True]: outputs.append( - build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0] + build_and_run( + func, + inputs, + 1, + None, + device, + enable_acl=acl, + tvm_ops=tvm_ops, + acl_partitions=acl_partitions, + config=config, + )[0] ) verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) @@ -283,25 +305,25 @@ def test_codegen_pooling(): fp32_dtype = ("float32", -127, 128) uint8_dtype = ("uint8", 0, 255) - + # fmt: off trials = [ - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16)], - ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (16, 16, 16)], - ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (15, 15, 16)], - ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (15, 15, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16), (0, 1),], ] - + # fmt: on for ( typef, (dtype, low, high), @@ -312,6 +334,7 @@ def test_codegen_pooling(): ceil_mode, count_include_pad, input_shape, + (tvm_ops, acl_partitions), ) in trials: shape = (1, *input_shape) inputs = {"a"} @@ -319,7 +342,7 @@ def test_codegen_pooling(): func = _get_pooling_model(*args, iter(inputs)) exp_codegen = _get_expected_pooling_codegen(*args) - verify_codegen(func, exp_codegen, 1) + verify_codegen(func, exp_codegen, acl_partitions, tvm_ops) def test_codegen_global_pooling(): diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index b4fc2b283369..2b99879d8227 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -158,13 +158,14 @@ def test_quantized_matmul_add(): def verify_batch_matmul( - batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32" + batch_a, batch_b, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32" ): - ashape = (batch, l, n) if transa else (batch, n, l) - bshape = (batch, m, l) if transb else (batch, l, m) + batch = max(batch_a, batch_b) + ashape = (batch_a, l, n) if transa else (batch_a, n, l) + bshape = (batch_b, m, l) if transb else (batch_b, l, m) A = te.placeholder(ashape, name="A", dtype=dtype) B = te.placeholder(bshape, name="B", dtype=dtype) - C = cblas.batch_matmul(A, B, transa, transb) + C = lib.batch_matmul(A, B, transa, transb) D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D") s = te.create_schedule(D.op) @@ -207,24 +208,32 @@ def verify(target="llvm"): def test_batch_matmul(): - verify_batch_matmul(16, 235, 128, 1024, cblas) - verify_batch_matmul(16, 235, 128, 1024, cblas, True, False) - verify_batch_matmul(16, 235, 128, 1024, cblas, False, True) - verify_batch_matmul(16, 235, 128, 1024, cblas, True, True) - verify_batch_matmul(16, 235, 128, 1024, mkl) - verify_batch_matmul(16, 235, 128, 1024, mkl, True, False) - verify_batch_matmul(16, 235, 128, 1024, mkl, False, True) - verify_batch_matmul(16, 235, 128, 1024, mkl, True, True) - verify_batch_matmul(1, 1, 16, 3, cblas) - verify_batch_matmul(1, 1, 16, 3, cblas, True, False) - verify_batch_matmul(1, 1, 16, 3, cblas, False, False) - verify_batch_matmul(1, 1, 16, 3, cblas, True, True) - verify_batch_matmul(1, 1, 16, 3, cblas, iterative=True) - verify_batch_matmul(1, 1, 16, 3, mkl) - verify_batch_matmul(1, 1, 16, 3, mkl, True, False) - verify_batch_matmul(1, 1, 16, 3, mkl, False, False) - verify_batch_matmul(1, 1, 16, 3, mkl, True, True) - verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, False) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, False, True) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, True) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, False) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, False, True) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, True) + verify_batch_matmul(16, 1, 235, 128, 1024, cblas) + verify_batch_matmul(1, 16, 235, 128, 1024, cblas) + verify_batch_matmul(16, 1, 235, 128, 1024, cblas, iterative=True) + verify_batch_matmul(1, 16, 235, 128, 1024, cblas, iterative=True) + verify_batch_matmul(16, 1, 235, 128, 1024, mkl) + verify_batch_matmul(1, 16, 235, 128, 1024, mkl) + verify_batch_matmul(16, 1, 235, 128, 1024, mkl, iterative=True) + verify_batch_matmul(1, 16, 235, 128, 1024, mkl, iterative=True) + verify_batch_matmul(1, 1, 1, 16, 3, cblas) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, False) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, False, False) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, True) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, iterative=True) + verify_batch_matmul(1, 1, 1, 16, 3, mkl) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, False) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, False, False) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, True) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, iterative=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index a0f51ca7c9fc..648100a569d7 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -112,33 +112,23 @@ def verify(target="cuda"): verify() -def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): - j = 16 - n = 1024 - l = 128 - m = 236 - A = te.placeholder((j, n, l), name="A", dtype=in_dtype) - B = te.placeholder((j, l, m), name="B", dtype=in_dtype) +def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): + A = te.placeholder(Ashape, name="A", dtype=in_dtype) + B = te.placeholder(Bshape, name="B", dtype=in_dtype) C = cublas.batch_matmul(A, B, dtype=out_dtype) s = te.create_schedule(C.op) - def verify(target="cuda"): - if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): - print("skip because extern function is not available") - return - dev = tvm.cuda(0) - f = tvm.build(s, [A, B, C], target) - a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), dev) - f(a, b, c) - tvm.testing.assert_allclose( - c.numpy(), - np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), - rtol=rtol, - ) - - verify() + dev = tvm.cuda(0) + f = tvm.build(s, [A, B, C], "cuda") + a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev) + f(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), + np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), + rtol=rtol, + ) @tvm.testing.requires_cuda @@ -156,9 +146,20 @@ def test_matmul_add_igemm(): @tvm.testing.requires_cuda def test_batch_matmul(): - verify_batch_matmul("float", "float") - verify_batch_matmul("float16", "float") - verify_batch_matmul("float16", "float16", rtol=1e-2) + if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + print("skip because extern function is not available") + return + + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul( + (16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) + verify_batch_matmul( + (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index ba03acc1c112..92e8f11a2312 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -326,5 +326,5 @@ def get_ethosn_api_version(): def get_ethosn_variant(): ethosn_variant_config = os.getenv("ETHOSN_VARIANT_CONFIG") if ethosn_variant_config is not None: - return 3 - return 0 + return "Ethos-N78_1TOPS_2PLE_RATIO" + return "Ethos-N77" diff --git a/tests/python/contrib/test_ethosn/test_conv2d.py b/tests/python/contrib/test_ethosn/test_conv2d.py index ca551603d13f..845cec593105 100644 --- a/tests/python/contrib/test_ethosn/test_conv2d.py +++ b/tests/python/contrib/test_ethosn/test_conv2d.py @@ -188,10 +188,6 @@ def test_conv2d_failure(): _scale_error_msg = ( "Overall scale (of the input * weights / output) should be in the range [0, 1)" ) - if tei.get_ethosn_api_version() == 2008: - _scale_error_msg = ( - "Overall scale (of the input * weights / output) should be in the range [0, 1}" - ) trials = [ ( diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index ce89c90d9379..f9a3549576c3 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -123,15 +123,11 @@ def test_mobilenet_v1(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"bfb5a50607edb50009c58ae9d4287e4d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"896c28b4f06341ea638ead3a593e1aed"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"47e216d8ab2bf491708ccf5620bc0d02"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"2436f523e263f66a063cef902f2f43d7"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"9298b6c51e2a82f70e91dd11dd6af412"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"407eb47346c8afea2d15e8f0d1c079f2"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -153,15 +149,11 @@ def test_inception_v3(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"96116d7e6c7385de0688074a3f889983"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"551cde850c6ef960d19be4f317fb8e68"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"8c9d75659cd7bc9ff6dd6d490d28f9b2"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"cdd4d7f6453d722ea73224ff9d6a115a"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"d44eece5027ff56e5e7fcf014367378d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"1ba555b4bc60c428018a0f2de9d90532"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -182,17 +174,11 @@ def test_inception_v4(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"b34aec2a48c591818761ed6b42c133e5"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"30f078bd42757e8686eafa1f28d0d352"} - if tei.get_ethosn_api_version() == 2008: - if not tei.get_ethosn_variant() == 0: - pytest.skip( - "Ethos-N78 20.08 does not support inception_v4 in the default configuration." - ) - _compile_hash = {"798292bfa596ca7c32086396b494b46c"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"851665c060cf4719248919d17325ae02"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -213,15 +199,11 @@ def test_ssd_mobilenet_v1(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"c312edfc9a946ed4dc7c049d472dae6e", "3183f0fa5eba8f6b9557d14eaf47842d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"deee52e136327436411fc725624ae2ea", "6526509d3cbee014e38c79e22bb29d7f"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"5999f26e140dee0d7866491997ef78c5", "24e3a690a7e95780052792d5626c85be"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"da871b3f03a93df69d704ed44584d6cd", "9f52411d301f3cba3f6e4c0f1c558e87"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"6e8c4586bdd26527c642a4f016f52284", "057c5efb094c79fbe4483b561147f1d2"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"dc687e60a4b6750fe740853f22aeb2dc", "1949d86100004eca41099c8e6fa919ab"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index f9912c9674e5..b54da208b33d 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1246,12 +1246,12 @@ def test_tensorrt_dynamic_batch(): def test_tensorrt_dynamic_batch_conv(): if skip_codegen_test(): return - batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2] + batches_to_test = [1, 5, 1, 0, 2, 3, 0, 1, 3, 2] x_shape = (relay.Any(), 32, 8, 8) x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") k_shape = (16, 32, 3, 3) params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")} - result_arr = [{} for _ in range(len(batches_to_test))] + result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] for use_trt in [True, False]: x = relay.var("x", shape=x_shape, dtype="float32") kernel = relay.var("kernel", shape=k_shape, dtype="float32") @@ -1263,15 +1263,21 @@ def test_tensorrt_dynamic_batch_conv(): mod, _ = tensorrt.partition_for_tensorrt(mod, params) if not skip_runtime_test(): - with relay.build_config(opt_level=3): - relay_exec = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm") + for target in ["llvm", "cuda"]: + with relay.build_config(opt_level=3): + relay_exec = relay.create_executor( + "vm", mod=mod, device=tvm.cpu(0), target="llvm" + ) - for i, batch_size in enumerate(batches_to_test): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...], **params) + for i, batch_size in enumerate(batches_to_test): + result_arr[i][target][use_trt] = relay_exec.evaluate()( + x_data[:batch_size, ...], **params + ) if not skip_runtime_test(): for i in range(len(batches_to_test)): - assert_result_dict_holds(result_arr[i]) + for target in ["llvm", "cuda"]: + assert_result_dict_holds(result_arr[i][target]) def test_maskrcnn_resnet50() -> None: diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index f7cbf92bca30..9c0d8fa8911e 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -41,7 +41,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir): return os.path.join(temp_dir, model_sub_path) -def get_sample_compiled_module(target_dir, package_filename): +def get_sample_compiled_module(target_dir, package_filename, output_format="so"): """Support function that returns a TFLite compiled module""" base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" @@ -53,7 +53,10 @@ def get_sample_compiled_module(target_dir, package_filename): tvmc_model = tvmc.frontends.load_model(model_file) return tvmc.compiler.compile_model( - tvmc_model, target="llvm", package_path=os.path.join(target_dir, package_filename) + tvmc_model, + target="llvm", + package_path=os.path.join(target_dir, package_filename), + output_format=output_format, ) @@ -182,6 +185,24 @@ def tflite_compiled_model(tmpdir_factory): return get_sample_compiled_module(target_dir, "mock.tar") +@pytest.fixture(scope="session") +def tflite_compiled_model_mlf(tmpdir_factory): + + # Not all CI environments will have TFLite installed + # so we need to safely skip this fixture that will + # crash the tests that rely on it. + # As this is a pytest.fixture, we cannot take advantage + # of pytest.importorskip. Using the block below instead. + try: + import tflite + except ImportError: + print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.") + return "" + + target_dir = tmpdir_factory.mktemp("data") + return get_sample_compiled_module(target_dir, "mock.tar", "mlf") + + @pytest.fixture(scope="session") def imagenet_cat(tmpdir_factory): tmpdir_name = tmpdir_factory.mktemp("data") diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py new file mode 100644 index 000000000000..48be5a810bc5 --- /dev/null +++ b/tests/python/driver/tvmc/test_mlf.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import os + +import tvm +from tvm.driver import tvmc +from tvm.driver.tvmc.main import _main +from tvm.driver.tvmc.model import TVMCPackage, TVMCException + + +def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): + pytest.importorskip("tflite") + + output_dir = tmpdir_factory.mktemp("mlf") + input_model = tflite_mobilenet_v1_1_quant + output_file = os.path.join(output_dir, "mock.tar") + + # Compile the input model and generate a Model Library Format (MLF) archive. + tvmc_cmd = ( + f"tvmc compile {input_model} --target='llvm' --output {output_file} --output-format mlf" + ) + tvmc_args = tvmc_cmd.split(" ")[1:] + _main(tvmc_args) + assert os.path.exists(output_file), "Could not find the exported MLF archive." + + # Run the MLF archive. It must fail since it's only supported on micro targets. + tvmc_cmd = f"tvmc run {output_file}" + tvmc_args = tvmc_cmd.split(" ")[1:] + exit_code = _main(tvmc_args) + on_error = "Trying to run a MLF archive must fail because it's only supported on micro targets." + assert exit_code != 0, on_error + + +def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + mod, params = tvmc_model.mod, tvmc_model.params + + graph_module = tvm.relay.build(mod, target="llvm", params=params) + + output_dir = tmpdir_factory.mktemp("mlf") + output_file = os.path.join(output_dir, "mock.tar") + + # Try to export MLF with no cross compiler set. No exception must be thrown. + tvmc_model.export_package( + executor_factory=graph_module, + package_path=output_file, + cross=None, + output_format="mlf", + ) + assert os.path.exists(output_file), "Could not find the exported MLF archive." + + # Try to export a MLF whilst also specifying a cross compiler. Since + # that's not supported it must throw a TVMCException and report the + # reason accordingly. + with pytest.raises(TVMCException) as exp: + tvmc_model.export_package( + executor_factory=graph_module, + package_path=output_file, + cross="cc", + output_format="mlf", + ) + expected_reason = "Specifying the MLF output and a cross compiler is not supported." + on_error = "A TVMCException was caught but its reason is not the expected one." + assert str(exp.value) == expected_reason, on_error + + +def test_tvmc_import_package_mlf(tflite_compiled_model_mlf): + pytest.importorskip("tflite") + + # Compile and export a model to a MLF archive so it can be imported. + exported_tvmc_package = tflite_compiled_model_mlf + archive_path = exported_tvmc_package.package_path + + # Import the MLF archive. TVMCPackage constructor will call import_package method. + tvmc_package = TVMCPackage(archive_path) + + assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive." + assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive." + assert tvmc_package.graph is not None, ".graph must be set in the MLF archive." + assert tvmc_package.params is not None, ".params must be set in the MLF archive." + assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format." diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 078076b479ea..476fac5da1b9 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -192,6 +192,11 @@ def test_target_from_cli__error_duplicate(): _ = tvmc.common.target_from_cli("llvm, llvm") +def test_target_invalid_more_than_two_tvm_targets(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("cuda, opencl, llvm") + + def test_target_from_cli__error_target_not_found(): with pytest.raises(TVMCException): _ = tvmc.common.target_from_cli("invalidtarget") @@ -202,6 +207,18 @@ def test_target_from_cli__error_no_tvm_target(): _ = tvmc.common.target_from_cli("ethos-n77") +def test_target_two_tvm_targets(): + tvm_target, extra_targets = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" + ) + + assert "opencl" in str(tvm_target) + assert "llvm" in str(tvm_target.host) + + # No extra targets + assert 0 == len(extra_targets) + + def test_tokenize_target_with_opts(): tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index f4c0cd102340..d1396f0435d0 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -452,6 +452,33 @@ def test_forward_Deconvolution(): bias_filler=dict(type="xavier"), ), ) + _test_deconvolution( + data, + convolution_param=dict( + num_output=16, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=16, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ), + ) + data = np.random.rand(1, 100, 32, 32).astype(np.float32) + _test_deconvolution( + data, + convolution_param=dict( + num_output=100, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=100, + ), + ) ####################################################################### diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d1ecfc5559a4..6ac747c5ea94 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4129,6 +4129,7 @@ def verify_softplus(indata): verify_softplus(input_data) +@tvm.testing.uses_gpu def test_cumsum(): def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): cumsum_node = onnx.helper.make_node( @@ -4205,6 +4206,30 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") +@tvm.testing.uses_gpu +def test_eyelike(): + def verify_eyelike(indata): + node = helper.make_node( + "EyeLike", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [node], + "eyelike_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))], + ) + + model = helper.make_model(graph, producer_name="eyelike_test") + + verify_with_ort_with_inputs(model, [indata], dtype="float32", opset=9) + + input_data = np.zeros((5, 5), dtype=np.float32) + verify_eyelike(input_data) + + """ The following parameterized tests loads the tests that ONNX ships as serialized ONNX files, inputs, and outputs. The goal of this test @@ -4241,9 +4266,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_cumsum_2d_negative_axis/", "test_det_2d/", "test_det_nd/", - "test_eyelike_populate_off_main_diagonal/", - "test_eyelike_with_dtype/", - "test_eyelike_without_dtype/", "test_matmulinteger/", "test_maxpool_2d_same_lower/", "test_maxpool_2d_same_upper/", @@ -4277,17 +4299,32 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_tfidfvectorizer_tf_onlybigrams_levelempty/", "test_tfidfvectorizer_tf_onlybigrams_skip5/", "test_tfidfvectorizer_tf_uniandbigrams_skip5/", - "test_unique_not_sorted_without_axis/", "test_unique_sorted_with_axis/", "test_unique_sorted_with_axis_3d/", "test_unique_sorted_with_negative_axis/", - "test_unique_sorted_without_axis/", "test_upsample_nearest/", ] +targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] + +target_skips = { + "cuda": [ + "test_mod_mixed_sign_float16/", + "test_qlinearconv/", + "test_resize_upsample_sizes_nearest/", + ] +} + + +@pytest.mark.parametrize("target", targets) @pytest.mark.parametrize("test", onnx_test_folders) -def test_onnx_nodes(test): +def test_onnx_nodes(test, target): + if target in target_skips: + for failure in target_skips[target]: + if failure in test: + pytest.skip() + break for failure in unsupported_onnx_tests: if failure in test: pytest.skip() @@ -4313,12 +4350,14 @@ def test_onnx_nodes(test): outputs.append(numpy_helper.to_array(new_tensor)) else: raise ImportError(str(tensor) + " not labeled as an import or an output") - tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) - if len(outputs) == 1: - tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol) - else: - for output, val in zip(outputs, tvm_val): - tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol) + + dev = tvm.device(target, 0) + tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev) + if len(outputs) == 1: + tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol) + else: + for output, val in zip(outputs, tvm_val): + tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol) def test_wrong_input(): @@ -4682,4 +4721,5 @@ def repeat(N, D): test_wrong_input() test_aten() test_reverse_sequence() + test_eyelike() test_qlinearconv() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 07f0d8e75c4d..be4d74ed205a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -643,6 +643,10 @@ def test_forward_prelu(): input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data) + # Test when input channel > 1 and num parameters = 1 + verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=input_data) + # Test when input dims < 2 + verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=torch.randn(2)) @tvm.testing.uses_gpu diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f29450dbb604..331553388b48 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -151,7 +151,6 @@ def run_tvm_graph( return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - print(mod["main"]) mod = relay.transform.InferType()(mod) vm_exec = relay.vm.compile(mod, target="llvm", params=params) if serialize: @@ -3438,16 +3437,18 @@ def _test_forward_combined_nms( "nms/CombinedNonMaxSuppression:2", "nms/CombinedNonMaxSuppression:3", ], - mode="vm", ) def test_forward_combined_nms(): """CombinedNonMaxSuppression""" _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 1), 0.7, 0.5, 64, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 1), 0.7, 0.5, 10, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 2), 0.7, 0.5, 32, 64) _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 20), 0.7, 0.5, 64, 10) _test_forward_combined_nms((1, 64, 20, 4), (1, 64, 20), 0.7, 0.5, 64, 64, clip_boxes=True) _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 1), 0.4, 0.6, 100, 100) + _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 10), 0.4, 0.2, 150, 1000) ####################################################################### diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index e30ee7b0c993..9686909ff31f 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -23,8 +23,7 @@ from tvm.runtime.vm import VirtualMachine import tvm.contrib.graph_executor as runtime -from tvm.relay.frontend.tensorflow import from_tensorflow - +from tvm.relay.frontend.tensorflow2 import from_tensorflow import tvm.testing from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list @@ -34,20 +33,20 @@ def run_tf_code(func, input_): if type(func) is Function: - out = func(input_) - if isinstance(out, list): - a = [x.numpy() for x in out] + f_out = func(input_) + if isinstance(f_out, (list, tuple)): + np_out = [x.numpy() for x in f_out] else: - a = [out.numpy()] + np_out = [f_out.numpy()] else: - a = func(tf.constant(input_)) - if type(a) is dict: - a = [x.numpy() for x in a.values()] - elif type(a) is list: - a = [x.numpy() for x in a] + f_out = func(tf.constant(input_)) + if type(f_out) is dict: + np_out = [f_out[k].numpy() for k in sorted(f_out.keys())] + elif type(f_out) is list: + np_out = [x.numpy() for x in f_out] else: - a = a.numpy() - return a + np_out = f_out.numpy() + return np_out def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3): @@ -72,7 +71,7 @@ def run_graph_executor(lib, input_, ctx=tvm.cpu(0)): mod = runtime.GraphModule(lib["default"](ctx)) mod.set_input(0, input_) mod.run() - return [mod.get_output(0).asnumpy()] + return [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 40d42a28025a..b3504ff38328 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -49,13 +49,15 @@ def _model_graph(TestClass): return gdef, input_, output -def run_all(TestClass): - def run_func_graph(TestClass, runtime="vm"): - compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) +def run_func_graph(TestClass, runtime="vm", outputs=None): + compare_tf_tvm(*_function_graph(TestClass), runtime=runtime, output_tensors=outputs) + + +def run_model_graph(TestClass, outputs=None): + compare_tf_tvm(*_model_graph(TestClass), runtime="vm", output_tensors=outputs) - def run_model_graph(TestClass): - compare_tf_tvm(*_model_graph(TestClass), runtime="vm") +def run_all(TestClass): run_model_graph(TestClass) for runtime_ in ["vm", "graph"]: run_func_graph(TestClass, runtime=runtime_) @@ -63,7 +65,7 @@ def run_model_graph(TestClass): def test_add_one(): class AddOne(tf.Module): - """ simple function to test x=x+1; scalar as input""" + """simple function to test x=x+1; scalar as input""" def get_input(self): return np.array(1.0, dtype="float32") @@ -357,5 +359,93 @@ def func(self, x): run_all(ConcatV2) +def test_multi_output(): + class MultiOutput(tf.Module): + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + y = 2 * x + return x, y + + run_func_graph(MultiOutput, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"]) + run_func_graph( + MultiOutput, runtime="graph", outputs=["Identity:output:0", "Identity_1:output:0"] + ) + run_model_graph(MultiOutput, outputs=["Identity:output:0"]) + + +def test_if(): + def create_if_class(_condition=True): + class If(tf.Module): + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def double(x): + return 2 * x + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def triple(x): + return 3 * x + + output = tf.raw_ops.If( + cond=_condition, + input=[x], + Tout=[tf.float32], + output_shapes=[(2, 2)], + then_branch=double.get_concrete_function(), + else_branch=triple.get_concrete_function(), + ) + return output[0] + + return If + + for cond in [True, False]: + if_class = create_if_class(_condition=cond) + run_func_graph(if_class, runtime="vm") + run_model_graph(if_class) + + +def test_stateless_while(): + class StatelessWhile(tf.Module): + def get_input(self): + return np.array([6], dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) + def func(self, x): + i = tf.constant(3.0) + cond = lambda i: tf.less(i, x) + body = lambda i: (tf.add(i, 2),) + r = tf.while_loop(cond, body, [i]) + return r[0] + + run_func_graph(StatelessWhile, runtime="vm") + run_model_graph(StatelessWhile) + + +def test_stateless_while_2var(): + class StatelessWhile2Var(tf.Module): + def get_input(self): + return np.array([20], dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) + def func(self, x): + i = tf.constant(3.0) + j = tf.constant(5.0) + cond = lambda i, j: tf.less(i + j, x) + body = lambda i, j: (tf.add(i, 2), tf.add(j, 3)) + r = tf.while_loop(cond, body, [i, j]) + return r + + run_func_graph( + StatelessWhile2Var, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"] + ) + run_model_graph(StatelessWhile2Var, outputs=["Identity:output:0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index c1917674873d..a54ffb80f051 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -165,12 +165,18 @@ def extract_main_workspace_sizebytes(extract_dir): def compile_and_run( - mod, input_list, output_list, use_calculated_workspaces, params=None, workspace_byte_alignment=8 + mod, + input_list, + output_list, + target_options, + use_calculated_workspaces, + params=None, + workspace_byte_alignment=8, ): """ This method verifies the generated source """ - target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment}" + target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment} {target_options}" cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " # The calculated workspaces will not account for stack allocator tags used for debugging diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 02b4de3a64f3..4f8de450d9f1 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -44,7 +44,8 @@ @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_conv_with_params(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_conv_with_params(use_calculated_workspaces, target_options): RELAY_MODEL = """ #[version = "0.0.5"] def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { @@ -73,11 +74,12 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), output_list = generate_ref_data(mod, inputs, params) input_list = [input_data] - compile_and_run(mod, input_list, output_list, use_calculated_workspaces, params) + compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces, params) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_add_with_params(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_add_with_params(use_calculated_workspaces, target_options): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) @@ -91,11 +93,14 @@ def test_add_with_params(use_calculated_workspaces): output_list = generate_ref_data(func, inputs, params) input_list = [y_in] - compile_and_run(func, input_list, output_list, use_calculated_workspaces, params) + compile_and_run( + func, input_list, output_list, target_options, use_calculated_workspaces, params + ) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_conv2d(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_conv2d(use_calculated_workspaces, target_options): """Test a subgraph with a single conv2d operator.""" def conv2d_direct(): @@ -137,11 +142,12 @@ def group_conv2d(): for mod, inputs, out_shape in [conv2d_direct(), group_conv2d()]: output_list = generate_ref_data(mod, inputs) input_list = [inputs["data"], inputs["weight"]] - compile_and_run(mod, input_list, output_list, use_calculated_workspaces) + compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_concatenate(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_concatenate(use_calculated_workspaces, target_options): dtype = "float32" x = relay.var("x", shape=(10, 5), dtype=dtype) y = relay.var("y", shape=(10, 5), dtype=dtype) @@ -157,11 +163,12 @@ def test_concatenate(use_calculated_workspaces): output_list = generate_ref_data(func, inputs) input_list = [inputs["x"], inputs["y"], inputs["z"]] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_nested_tuples(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_nested_tuples(use_calculated_workspaces, target_options): x = relay.var("x", shape=(10,)) x1 = x + relay.const(1.0) x2 = x1 + relay.const(1.0) @@ -174,39 +181,43 @@ def test_nested_tuples(use_calculated_workspaces): inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) input_list = [x_data] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_tuple_getitem(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_tuple_getitem(use_calculated_workspaces, target_options): func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) output_list = generate_ref_data(func, {}) input_list = [] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_id(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_id(use_calculated_workspaces, target_options): x = relay.var("x", "float32") ident = relay.Function([x], x) one = np.array(1.0, "float32") inputs = {"x": one} output_list = generate_ref_data(ident, inputs) input_list = [one] - compile_and_run(ident, input_list, output_list, use_calculated_workspaces) + compile_and_run(ident, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_add_const(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_add_const(use_calculated_workspaces, target_options): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) output_list = generate_ref_data(func, {}) input_list = [] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_mul_param(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_mul_param(use_calculated_workspaces, target_options): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) func = relay.Function([x, y], relay.multiply(x, y)) @@ -215,11 +226,12 @@ def test_mul_param(use_calculated_workspaces): inputs = {"x": x_data, "y": y_data} output_list = generate_ref_data(func, inputs) input_list = [inputs["x"], inputs["y"]] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_subtract(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_subtract(use_calculated_workspaces, target_options): i = relay.var("i", shape=[], dtype="int32") sub = relay.subtract(i, relay.const(1, dtype="int32")) func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32")) @@ -227,11 +239,12 @@ def test_subtract(use_calculated_workspaces): inputs = {"i": i_data} output_list = generate_ref_data(func, inputs) input_list = [inputs["i"]] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_tuple_output(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) +def test_tuple_output(use_calculated_workspaces, target_options): x = relay.var("x", shape=(6, 9)) y = relay.split(x, 3).astuple() a = relay.TupleGetItem(y, 0) @@ -243,15 +256,17 @@ def test_tuple_output(use_calculated_workspaces): inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) input_list = [inputs["x"]] - compile_and_run(func, input_list, output_list, use_calculated_workspaces) + compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) @pytest.mark.parametrize( "use_calculated_workspaces_and_alignment", [(True, 1), (True, 16), (False, 1)] ) -def test_mobilenet(use_calculated_workspaces_and_alignment): +@pytest.mark.parametrize("target_options", ["--unpacked-api"]) +def test_mobilenet(use_calculated_workspaces_and_alignment, target_options): use_calculated_workspaces = use_calculated_workspaces_and_alignment[0] workspace_byte_alignment = use_calculated_workspaces_and_alignment[1] + mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") @@ -259,7 +274,13 @@ def test_mobilenet(use_calculated_workspaces_and_alignment): output_list = generate_ref_data(mod, inputs, params) input_list = [inputs["data"]] compile_and_run( - mod, input_list, output_list, use_calculated_workspaces, params, workspace_byte_alignment + mod, + input_list, + output_list, + target_options, + use_calculated_workspaces, + params, + workspace_byte_alignment, ) @@ -318,7 +339,8 @@ def visit_call(self, call): @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -def test_byoc_utvm(use_calculated_workspaces): +@pytest.mark.parametrize("target_options", [""]) +def test_byoc_utvm(use_calculated_workspaces, target_options): """This is a simple test case to check BYOC capabilities of AOT""" x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) @@ -361,7 +383,7 @@ def test_byoc_utvm(use_calculated_workspaces): output_list = generate_ref_data(mod, map_inputs) input_list = [map_inputs["x"]] input_list.extend([map_inputs["w{}".format(i)] for i in range(8)]) - compile_and_run(mod, input_list, output_list, use_calculated_workspaces) + compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces) if __name__ == "__main__": diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 11f4515fbb1e..57f07b3f00e5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -25,6 +25,7 @@ from tvm.relay.testing import run_infer_type as infer_type from utils.assert_diagnostic import DiagnosticTesting +from utils import ref_funcs def int32(val): @@ -1031,7 +1032,7 @@ def verify_any_strided_slice( mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype="float32") if const_attrs: - data = relay.var("data", shape=data_np_shape, dtype="float32") + data = relay.var("data", shape=data_shape, dtype="float32") begin = relay.const(np_begin) end = relay.const(np_end) strides = relay.const(np_strides) @@ -1610,7 +1611,8 @@ def verify_all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - expected_indices, + expected, + output_format="onnx", ): batch_size = boxes_np.shape[0] num_classes = scores_np.shape[1] @@ -1621,23 +1623,23 @@ def verify_all_class_non_max_suppression( ) nms_out = relay.vision.all_class_non_max_suppression( - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ) - three = relay.const(np.array([3]), dtype="int64") - begin = relay.const(np.array([0, 0]), dtype="int64") - end = relay.op.concatenate([nms_out[1], three], axis=0) - strides = relay.const(np.array([1, 1]), dtype="int64") - out = relay.op.strided_slice(nms_out[0], begin, end, strides) - - mod = tvm.IRModule() - mod["main"] = relay.Function([boxes, scores], out) - - check_result([boxes_np, scores_np], mod, [expected_indices]) + if output_format == "onnx": + three = relay.const(np.array([3]), dtype="int64") + begin = relay.const(np.array([0, 0]), dtype="int64") + end = relay.op.concatenate([nms_out[1], three], axis=0) + strides = relay.const(np.array([1, 1]), dtype="int64") + out = relay.op.strided_slice(nms_out[0], begin, end, strides) + mod = tvm.IRModule() + mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, [expected]) + else: + out = nms_out.tuple_value + mod = tvm.IRModule() + mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, expected) boxes = np.array( [ @@ -1667,6 +1669,39 @@ def verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected ) + expected = [ + np.array( + [[[0, 4], [0, 2], [1, 4], [1, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]] + ), + np.array( + [ + [ + 0.9, + 0.6, + 0.9, + 0.8, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ] + ), + np.array([4]), + ] + + verify_all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected, + output_format="tensorflow", + ) + boxes = np.array( [ [ @@ -1703,5 +1738,52 @@ def verify_all_class_non_max_suppression( ) +@tvm.testing.uses_gpu +def test_gather_nd(): + def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0): + x = relay.var("x", relay.TensorType(data_shape, "float32")) + y = relay.var("y", relay.TensorType(indices_shape, "int32")) + z = relay.gather_nd(x, y, batch_dims, indices_shape[0]) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + data_np = np.random.uniform(size=data_shape_np).astype("float32") + indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32") + + ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims) + check_result([data_np, indices_np], mod, [ref_res]) + + verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3)) + verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3)) + verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1) + verify_gather_nd( + (relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2 + ) + + +@tvm.testing.uses_gpu +def test_scatter_nd(): + def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): + indices_shape = (2, relay.Any()) + updates_shape = (relay.Any(),) + data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) + indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype))) + updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype))) + + out = relay.op.scatter_nd(data, indices, updates, "add") + + mod = tvm.IRModule() + mod["main"] = relay.Function([data, indices, updates], out) + + check_result([data_np, indices_np, updates_np], mod, [ref_res]) + + data = np.zeros((2, 2)).astype("int64") + indices = np.array([[1, 1, 0], [0, 1, 0]]) + updates = np.array([2, 3, 0]) + out = np.array([[0, 0], [2, 3]]) + verify_scatter_nd(data, indices, updates, out) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 229b9905050c..f95a009f9dff 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -478,11 +478,17 @@ def test_no_match_func_attr(): def test_match_call_attr(): + # String attr is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"}) x = relay.var("x") y = relay.var("y") assert is_conv2d.match(relay.op.nn.conv2d(x, y)) + # Array attr + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3]) + assert is_conv2d.match(out) + # non-operator call attr_dict = {"call_attr": "attr"} call_has_attr = wildcard()(wildcard()).has_attr(attr_dict) @@ -508,6 +514,11 @@ def test_no_match_call_attr(): is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"}) assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + # Array attr + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1]) + assert not is_conv2d.match(out) + # non-operator calls call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"}) wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"}) diff --git a/tests/python/relay/test_ir_op.py b/tests/python/relay/test_ir_op.py index fe559697348b..edb8086dd426 100644 --- a/tests/python/relay/test_ir_op.py +++ b/tests/python/relay/test_ir_op.py @@ -17,6 +17,7 @@ import tvm from tvm import relay from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.relay.op import op as _op def test_op_attr(): @@ -103,11 +104,20 @@ def test_op_register(): """Tests register_op functionality.""" op_name = "custom_op" - tvm.ir.register_op(op_name) - tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256) - - assert tvm.ir.Op.get(op_name).name == op_name - assert tvm.ir.Op.get(op_name).num_inputs == 2 + _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code") + _op.get(op_name).set_num_inputs(2) + _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.") + _op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.") + # call default relation functions + _op.get(op_name).add_type_rel("Identity") + _op.get(op_name).set_support_level(1) + _op.register_pattern(op_name, _op.OpPattern.ELEMWISE) + _op.register_stateful(op_name, False) + + assert _op.get(op_name).name == op_name + assert _op.get(op_name).num_inputs == 2 + assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE + assert _op.get(op_name).get_attr("TOpIsStateful") == False if __name__ == "__main__": diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index c75a0f461eb6..099e127aeba9 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -171,6 +171,7 @@ def test_int_literal(): assert get_scalar(parse_text("0")) == 0 assert get_scalar(parse_text("-100")) == -100 assert get_scalar(parse_text("-05")) == -5 + assert get_scalar(parse_text("9223372036854775807")) == 9223372036854775807 def test_float_literal(): diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index b2ae28649e6a..4968660b95c8 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,6 +30,7 @@ def astext(program, unify_free_vars=False): text = program.astext() + print(text) if isinstance(program, Expr): roundtrip_program = tvm.parser.parse_expr(text) @@ -47,6 +48,17 @@ def show(text): print(text) +def test_large_graph(): + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + for i in range(int(1e6)): + z = relay.add(z, one) + f = relay.Function([x, y], z) + show(astext(f)) + + def test_func(): x = relay.var("x", shape=(3, 2)) y = relay.var("y") diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py index 8e401bc5670a..f968dbedddfe 100644 --- a/tests/python/relay/test_op_fast_math.py +++ b/tests/python/relay/test_op_fast_math.py @@ -23,11 +23,13 @@ from tvm import topi from tvm import te from tvm.contrib import graph_executor +from tvm.topi import testing -def test_fastmath(): +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_fastmath(target, dev): def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"): - a_np = np.arange(low, high, step).astype(dtype) + a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1)) b_np = f_numpy(a_np) x = relay.var("x", shape=a_np.shape, dtype="float32") @@ -36,13 +38,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"): mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]): - graph, lib, params = relay.build(mod, target="llvm", params=None) + graph, lib, params = relay.build(mod, target=target, params=None) # Check that the op related to fast math have been convered to function in lib func_name = "fused_" + name - assert lib.get_function(func_name) + # When there're multiple targets in tvm.testing.parametrize_targets, the function + # built will have a "_1" in function name + assert func_name in graph - dev = tvm.cpu(0) m = graph_executor.create(graph, lib, dev) # Set inputs m.set_input("x", tvm.nd.array(a_np, dev)) @@ -56,6 +59,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"): test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01) test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01) test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01) + test_apply( + relay.nn.fast_softmax, + "nn_fast_softmax", + tvm.topi.testing.softmax_python, + low=-10, + high=10, + step=0.01, + ) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fd6d7a9aeb14..fc67f0b90295 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -26,6 +26,7 @@ from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type +from utils import ref_funcs def test_zeros_ones(): @@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): else: y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32") - def gather_nd_batch_dims_1_ref(data, indices): - res = [] - for i, row in enumerate(data): - indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch - res.append(row[indices_tuple]) - # stack on the batch dim - return np.stack(res, 0) - - if batch_dims > 1: - x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:]) - y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :]) - - ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape) - - out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:] - ref_res = np.reshape(ref_res, out_shape) - elif batch_dims == 1: - ref_res = gather_nd_batch_dims_1_ref(x_data, y_data) - else: - ref_res = x_data[tuple(y_data)] + ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -1977,7 +1959,14 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + index = np.sort(index) # In unsorted case, need to sort the index of first occurence + return [ + uniq.astype(data.dtype), + index.astype("int32"), + inverse.astype("int32"), + num_uniq, + counts, + ] def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): if is_dyn: @@ -1998,18 +1987,26 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - tvm_res = intrp.evaluate()(x_data) - np_res = calc_numpy_unique(x_data, is_sorted) + tvm_res = intrp.evaluate()( + x_data + ) # unique, indices, inverse_indices, num_unique, (counts) + np_res = calc_numpy_unique( + x_data, is_sorted + ) # unique, indices, inverse_indices, num_unique, counts num_unique = np_res[3][0] - assert num_unique == tvm_res[2].numpy()[0] + + # num_unique + assert num_unique == tvm_res[3].numpy()[0] # unique tvm.testing.assert_allclose(tvm_res[0].numpy()[:num_unique], np_res[0], rtol=1e-5) + # indices + tvm.testing.assert_allclose(tvm_res[1].numpy()[:num_unique], np_res[1], rtol=1e-5) # inverse_indices - tvm.testing.assert_allclose(tvm_res[1].numpy(), np_res[1], rtol=1e-5) + tvm.testing.assert_allclose(tvm_res[2].numpy(), np_res[2], rtol=1e-5) # counts if return_counts: tvm.testing.assert_allclose( - tvm_res[3].numpy()[:num_unique], np_res[2], rtol=1e-5 + tvm_res[4].numpy()[:num_unique], np_res[4], rtol=1e-5 ) for dtype in ["int32", "int64"]: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c49e3de62662..c4d26a1811b1 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -379,7 +379,17 @@ def test_mean_var_std(): @tvm.testing.uses_gpu def test_strided_slice(): - def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"): + def verify( + dshape, + begin, + end, + strides, + output, + axes=None, + slice_mode="end", + test_ref=True, + dtype="int32", + ): x = relay.var("x", relay.TensorType(dshape, "float32")) ndim = len(dshape) begin = begin if begin else [0] * ndim @@ -387,12 +397,21 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") - ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode) + ref_res = tvm.topi.testing.strided_slice_python( + x_data, + begin, + end, + strides, + slice_mode, + axes=axes, + ) if strides: - z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode) + z = relay.strided_slice( + x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode + ) else: - z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode) + z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode) func = relay.Function([x], z) func = run_infer_type(func) @@ -436,24 +455,43 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False ) verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) + verify((3, 4, 3), [1], [4], None, None, axes=[1]) @tvm.testing.uses_gpu def test_dyn_strided_slice(): - def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"): + def verify( + dshape, + begin, + end, + strides, + output, + axes=None, + ishape=None, + slice_mode="end", + test_ref=True, + dtype="int32", + ): ndim = len(dshape) begin = begin if begin else [0] * ndim end = end if end else list(dshape) # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") - ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode) + ref_res = tvm.topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode, axes=axes + ) - x = relay.var("x", relay.TensorType((relay.Any(),) * ndim, "float32")) + if ishape is None: + ishape = (relay.Any(),) * ndim + + x = relay.var("x", relay.TensorType(ishape, "float32")) if strides: - z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode) + z = relay.strided_slice( + x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode + ) else: - z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode) + z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode) func = relay.Function([x], z) func = run_infer_type(func) @@ -483,13 +521,21 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - # TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin - # verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - # verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) verify( (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False ) verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) + verify( + (3, 4, 3, 2), + [1, 0], + [3, 1], + [1, 1], + None, + axes=[1, 3], + ishape=(relay.Any(), 4, relay.Any(), 2), + ) @tvm.testing.uses_gpu @@ -534,11 +580,12 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): if __name__ == "__main__": test_strided_slice() - test_strided_set() - test_binary_op() - test_cmp_type() - test_binary_int_broadcast_1() - test_binary_int_broadcast_2() - test_where() - test_reduce_functions() - test_mean_var_std() + test_dyn_strided_slice() + # test_strided_set() + # test_binary_op() + # test_cmp_type() + # test_binary_int_broadcast_1() + # test_binary_int_broadcast_2() + # test_where() + # test_reduce_functions() + # test_mean_var_std() diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index fd5c23fe35b1..70ea05fe1894 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -74,6 +74,15 @@ def test_int8_to_float32(): ) +def test_scalar_int8_to_float32(): + data = np.array(-128).astype("int8") + output = np.array(-63.5).astype("float32") + quant_args = {"in_zero_point": -1, "in_scale": 0.5} + dequantize_test_driver( + in_dtype="int8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1 + ) + + def test_int32_to_float32(): data = np.array([113, 29, -1052]).astype("int32") output = np.array([0.6550452, 0.16810896, -6.098297]).astype("float32") @@ -148,6 +157,7 @@ def test_dynamic_dequantize(): if __name__ == "__main__": test_uint8_to_float32() test_int8_to_float32() + test_scalar_int8_to_float32() test_int32_to_float32() test_channelwise_axis_1() test_channelwise_axis_0() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3031c55379ae..5c2793c607a9 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -770,6 +770,61 @@ def expected(): ) +@tvm.testing.uses_gpu +def test_alter_layout_strided_slice_axes_nhwc(): + """Test rewriting strided_slice with axes during alter_iop_layout""" + + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 32], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NHWC4c" + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + x = relay.layout_transform(x, "NHWC", "NHWC4c") + y = relay.op.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC4c", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 4], end=[1, 8], strides=[1, 1], axes=[0, 3]) + y = relay.layout_transform(y, "NHWC4c", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + + mod_before = tvm.IRModule() + mod_new = tvm.IRModule() + mod_before["main"] = a + mod_new["main"] = b + assert tvm.ir.structural_equal(mod_before, mod_new) + + def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -1298,3 +1353,4 @@ def expected(): test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() test_alter_op_dense() + test_alter_layout_strided_slice_axes_nhwc() diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 7288b6421de1..f0949ab19f9c 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -23,12 +23,19 @@ from tvm.contrib import graph_executor from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform +from tvm.ir.instrument import pass_instrument import tvm.testing -def _trace(module, metadata, _): - if metadata.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() +@tvm.instrument.pass_instrument +class Trace: + def run_before_pass(self, module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() + + def run_after_pass(self, module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() def check_graph_executor( @@ -49,7 +56,7 @@ def check_graph_executor( def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None): - with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config): + with tvm.transform.PassContext(opt_level=opt_level, instruments=[Trace()], config=config): mod = tvm.IRModule() mod["main"] = func exe = relay.vm.compile(mod, target) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index dd2dc979a731..4710d50ea8e4 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -735,7 +735,7 @@ def expected(): def test_conv_bn_convert_layout(): - """ Check that layout transforms are propagated through bn. """ + """Check that layout transforms are propagated through bn.""" def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -1097,7 +1097,7 @@ def expected(): def test_conv_convert_kernel_layout(): - """ Check that convolution kernel layout is correctly transformed. """ + """Check that convolution kernel layout is correctly transformed.""" def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -1235,6 +1235,49 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_strided_slice_axes_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + weight = relay.layout_transform(weight, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 1]) + + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_roi_pool_convert_layout(): def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1289,7 +1332,7 @@ def expected(): def test_default_keyword(): - """ Check that the default keyword selects correct TVM default layout. """ + """Check that the default keyword selects correct TVM default layout.""" def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1784,3 +1827,4 @@ def expected(): test_convert_with_config() test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() + test_conv_strided_slice_axes_convert_layout() diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py new file mode 100644 index 000000000000..3271379cf3ef --- /dev/null +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -0,0 +1,279 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-wildcard-import +import numpy as np +import pytest + +import tvm +from tvm import relay + + +def test_fake_quantize_conv(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_transpose_quantize_conv(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.qnn.op.quantize(op, one, zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_transpose_quantize_conv_bias_add(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) + op = relay.qnn.op.quantize(op, one, zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np, bias_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_maxpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.max_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_avgpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.all(np.abs(result - result2) <= 1) + + +def test_fake_quantize_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.reshape(x, [1, 3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_concat(): + zero = relay.const(0) + inputs = [] + for i in range(4): + inputs.append( + relay.qnn.op.dequantize( + relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero + ) + ) + concat = relay.op.concatenate(inputs, axis=1) + out = relay.qnn.op.quantize(concat, relay.const(3.5), zero) + + mod = tvm.IRModule.from_expr(out) + mod = tvm.relay.transform.InferType()(mod) + + inputs_np = [] + for i in range(4): + inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(*inputs_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(*inputs_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_clip(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114)) + op = relay.op.clip(x, 0, 6) + op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py index bb3fb84fc61f..f63b6ce0f23e 100644 --- a/tests/python/relay/test_pass_fast_math.py +++ b/tests/python/relay/test_pass_fast_math.py @@ -65,7 +65,19 @@ def test_erf(): assert "fast_erf" in fast_mod[0].astext() +def test_softmax(): + x = relay.var("x", shape=(1, 16), dtype="float32") + y = relay.nn.softmax(x) + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + + with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]): + fast_mod = relay.optimize(mod, target="llvm") + assert "nn.fast_softmax" in fast_mod[0].astext() + + if __name__ == "__main__": test_exp() test_tanh() test_erf() + test_softmax() diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py new file mode 100644 index 000000000000..610d4e4e491b --- /dev/null +++ b/tests/python/relay/test_pass_instrument.py @@ -0,0 +1,538 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Instrument test cases. +""" +import pytest +import tvm +import tvm.relay +from tvm.relay import op +from tvm.ir.instrument import PassTimingInstrument, pass_instrument + + +def get_test_model(): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + return tvm.IRModule.from_expr(e3 + e2) + + +def test_pass_timing_instrument(): + pass_timing = PassTimingInstrument() + + # Override current PassContext's instruments + tvm.transform.PassContext.current().override_instruments([pass_timing]) + + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = pass_timing.render() + assert "AnnotateSpans" in profiles + assert "ToANormalForm" in profiles + assert "InferType" in profiles + + # Reset current PassContext's instruments to None + tvm.transform.PassContext.current().override_instruments(None) + + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = pass_timing.render() + assert profiles == "" + + +def test_custom_instrument(): + @pass_instrument + class MyTest: + def __init__(self): + self.events = [] + + def enter_pass_ctx(self): + self.events.append("enter ctx") + + def exit_pass_ctx(self): + self.events.append("exit ctx") + + def run_before_pass(self, mod, info): + self.events.append("run before " + info.name) + + def run_after_pass(self, mod, info): + self.events.append("run after " + info.name) + + mod = get_test_model() + my_test = MyTest() + with tvm.transform.PassContext(instruments=[my_test]): + mod = tvm.relay.transform.InferType()(mod) + + assert ( + "enter ctx" + "run before InferType" + "run after InferType" + "exit ctx" == "".join(my_test.events) + ) + + +def test_disable_pass(): + @pass_instrument + class CustomPI: + def __init__(self): + self.events = [] + + def should_run(self, mod, info): + # Only run pass name contains "InferType" + if "InferType" not in info.name: + return False + return True + + def run_before_pass(self, mod, info): + self.events.append(info.name) + + mod = get_test_model() + custom_pi = CustomPI() + with tvm.transform.PassContext(instruments=[custom_pi]): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert "InferType" == "".join(custom_pi.events) + + +def test_multiple_instrument(): + @pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name + + def should_run(self, mod, info): + if self.skip_pass_name in info.name: + return False + return True + + skip_annotate = SkipPass("AnnotateSpans") + skip_anf = SkipPass("ToANormalForm") + + @pass_instrument + class PrintPassName: + def __init__(self): + self.events = [] + + def run_before_pass(self, mod, info): + self.events.append(info.name) + + mod = get_test_model() + print_pass_name = PrintPassName() + with tvm.transform.PassContext(instruments=[skip_annotate, skip_anf, print_pass_name]): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert "InferType" == "".join(print_pass_name.events) + + +def test_instrument_pass_counts(): + @pass_instrument + class PassesCounter: + def __init__(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def __clear(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def enter_pass_ctx(self): + self.__clear() + + def exit_pass_ctx(self): + self.__clear() + + def run_before_pass(self, mod, info): + self.run_before_count = self.run_before_count + 1 + + def run_after_pass(self, mod, info): + self.run_after_count = self.run_after_count + 1 + + mod = get_test_model() + passes_counter = PassesCounter() + with tvm.transform.PassContext(instruments=[passes_counter]): + tvm.relay.build(mod, "llvm") + assert passes_counter.run_after_count != 0 + assert passes_counter.run_after_count == passes_counter.run_before_count + + # Out of pass context scope, should be reset + assert passes_counter.run_before_count == 0 + assert passes_counter.run_after_count == 0 + + +def test_list_pass_configs(): + configs = tvm.transform.PassContext.list_configs() + + assert len(configs) > 0 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"]["type"] == "IntImm" + + +def test_enter_pass_ctx_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def enter_pass_ctx(self): + events.append(self.id + " enter ctx") + raise RuntimeError("Just a dummy error") + + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) + with pytest.raises(tvm.error.TVMError) as cm: + with pass_ctx: + pass + assert "Just a dummy error" in str(cm.execption) + + assert "%1 enter ctx" "%2 enter ctx" "%1 exit ctx" == "".join(events) + + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert cur_pass_ctx.instruments == None + + +def test_enter_pass_ctx_exception_global(): + @pass_instrument + class PIBroken: + def enter_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError) as cm: + cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) + assert not cur_pass_ctx.instruments + + +def test_exit_pass_ctx_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + raise RuntimeError("Just a dummy error") + + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) + with pytest.raises(tvm.error.TVMError) as cm: + with pass_ctx: + pass + assert "Just a dummy error" in str(cm.exception) + + assert "%1 exit ctx" "%2 exit ctx" == "".join(events) + + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert not cur_pass_ctx.instruments + + +def test_exit_pass_ctx_exception_global(): + @pass_instrument + class PIBroken: + def exit_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError) as cm: + cur_pass_ctx.override_instruments([PIBroken()]) + cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) + assert not cur_pass_ctx.instruments + + +def test_pass_exception(): + events = [] + + @pass_instrument + class PI: + def enter_pass_ctx(self): + events.append("enter_pass_ctx") + + def exit_pass_ctx(self): + events.append("exit_pass_ctx") + + def should_run(self, mod, info): + events.append("should_run") + return True + + def run_before_pass(self, mod, info): + events.append("run_before_pass") + + def run_after_pass(self, mod, info): + events.append("run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + raise RuntimeError("Just a dummy error") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI()]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "enter_pass_ctx" + "should_run" + "run_before_pass" + "transform pass" + "exit_pass_ctx" == "".join(events) + ) + + +def test_should_run_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + raise RuntimeError("Just a dummy error") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_run_before_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + raise RuntimeError("Just a dummy error") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_run_after_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + raise RuntimeError("Just a dummy error") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"] + mod = tvm.IRModule.from_expr(tvm.relay.add(x, y)) + + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%2 run_before_pass" + "transform pass" + "%1 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_instrument_call_sequence(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(" " + self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(" " + self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(" " + self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform1(mod, ctx): + events.append(" transform1 pass") + return mod + + @tvm.transform.module_pass(opt_level=2) + def transform2(mod, ctx): + events.append(" transform2 pass") + return mod + + mod = get_test_model() + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform1(mod) + mod = transform2(mod) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform1 pass" + " %1 run_after_pass" + " %2 run_after_pass" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform2 pass" + " %1 run_after_pass" + " %2 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 8a37da33a10f..95069d29fd84 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -74,7 +74,7 @@ def expected(): def test_legalize_none(): - """Test doing nothing by returning 'None' """ + """Test doing nothing by returning 'None'""" def before(): x = relay.var("x", shape=(1, 64, 56, 56)) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index edd46168a286..fb1094becb21 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -25,6 +25,7 @@ from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform +from tvm.ir import instrument as _instrument from tvm.relay.testing import run_infer_type import tvm.testing @@ -533,17 +534,26 @@ def test_print_ir(capfd): assert "multiply" in out -__TRACE_COUNTER__ = 0 +@tvm.instrument.pass_instrument +class PassCounter: + def __init__(self): + # Just setting a garbage value to test set_up callback + self.counts = 1234 + def enter_pass_ctx(self): + self.counts = 0 -def _tracer(module, info, is_before): - global __TRACE_COUNTER__ - if bool(is_before): - __TRACE_COUNTER__ += 1 + def exit_pass_ctx(self): + self.counts = 0 + + def run_before_pass(self, module, info): + self.counts += 1 + + def get_counts(self): + return self.counts def test_print_debug_callback(): - global __TRACE_COUNTER__ shape = (1, 2, 3) tp = relay.TensorType(shape, "float32") x = relay.var("x", tp) @@ -559,15 +569,20 @@ def test_print_debug_callback(): ] ) - assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with tvm.transform.PassContext(opt_level=3, trace=_tracer): + pass_counter = PassCounter() + with tvm.transform.PassContext(opt_level=3, instruments=[pass_counter]): + # Should be reseted when entering pass context + assert pass_counter.get_counts() == 0 mod = seq(mod) - # TODO(@jroesch): when we remove new fn pass behavior we need to remove - # change this back to 3 - assert __TRACE_COUNTER__ == 5 + # TODO(@jroesch): when we remove new fn pass behavior we need to remove + # change this back to match correct behavior + assert pass_counter.get_counts() == 6 + + # Should be cleanned up after exiting pass context + assert pass_counter.get_counts() == 0 if __name__ == "__main__": diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 01515a93546b..f79c79329b67 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -79,6 +79,23 @@ def test_threefry_sequential_generate(target, dev): ).any(), "Sequential generates should not have the same output" +@tvm.testing.parametrize_targets +def test_threefry_sequential_generate_remaining(target, dev): + key = tvm.relay.random.threefry_key(1) + key, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (7,)), 2) + _, rand2 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (7,)), 2) + out1, out2 = tvm.relay.create_executor( + "vm", + tvm.IRModule.from_expr(tvm.relay.Function([], tvm.relay.Tuple((rand1, rand2)))), + target=target, + device=dev, + ).evaluate()() + + assert ( + out1.asnumpy()[-3:] != out2.asnumpy()[-3:] + ).any(), "Sequential generates should not have the same output" + + def test_threefry_generate_infer(): oshape = (12,) key_type = tvm.relay.TensorType([10], dtype="uint64") @@ -137,12 +154,10 @@ def test_threefry_split_infer_fail(): @tvm.testing.requires_llvm -@pytest.mark.xfail(raises=tvm.error.TVMError) -def test_threefry_generate_incorrect_out_size(): +def test_threefry_generate_out_size(): key = tvm.relay.random.threefry_key(1) - # xfail: output size should be multiple of 4 key, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (5,)), 2) - out1, out2 = tvm.relay.create_executor( + out = tvm.relay.create_executor( "vm", tvm.IRModule.from_expr(tvm.relay.Function([], rand1)), target=tvm.target.Target("llvm"), @@ -154,3 +169,4 @@ def test_threefry_generate_incorrect_out_size(): test_threefry_repeatability(tvm.target.Target("llvm"), tvm.device("cpu")) test_threefry_split(tvm.target.Target("llvm"), tvm.device("cpu")) test_threefry_sequential_generate(tvm.target.Target("llvm"), tvm.device("cpu")) + test_threefry_sequential_generate_remaining(tvm.target.Target("llvm"), tvm.device("cpu")) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index e8179a37756c..a0d37844b837 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -22,6 +22,7 @@ from tvm import IRModule, te, relay, parser from tvm.relay import op, transform, analysis +from tvm.relay.op import op as _op def infer_mod(mod, annotate_spans=True): @@ -416,6 +417,134 @@ def test_dynamic_function(): assert mod["main"].params[0].checked_type == s_tt +def test_custom_op_infer(): + """Tests infer type for custom_op""" + op_name = "custom_log" + _op.register(op_name, r"code(cal log of a tensor.)code") + _op.get(op_name).set_num_inputs(1) + _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.") + # call default relation functions + _op.get(op_name).add_type_rel("Identity") + _op.get(op_name).set_support_level(1) + _op.register_pattern(op_name, _op.OpPattern.ELEMWISE) + _op.register_stateful(op_name, False) + + def clog(x): + return relay.Call(_op.get(op_name), [x]) + + tp = relay.TensorType((10, 10), "float32") + x = relay.var("x", tp) + sb = relay.ScopeBuilder() + t1 = sb.let("t1", clog(x)) + t2 = sb.let("t2", relay.add(t1, x)) + sb.ret(t2) + f = relay.Function([x], sb.get()) + fchecked = infer_expr(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) + + +def test_custom_add_broadcast_op(): + """Tests infer type for broadcast custom_op""" + op_name = "custom_broadcast_add" + _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code") + _op.get(op_name).set_num_inputs(2) + _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.") + _op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.") + # call default relation functions + _op.get(op_name).add_type_rel("Broadcast") + _op.get(op_name).set_support_level(1) + _op.register_stateful(op_name, False) + + def broadcast_add(x, y): + return relay.Call(_op.get(op_name), [x, y]) + + x = relay.var("x", shape=(10, 4)) + y = relay.var("y", shape=(5, 10, 1)) + z = broadcast_add(x, y) + func = relay.Function([x, y], z) + t1 = relay.TensorType((10, 4), "float32") + t2 = relay.TensorType((5, 10, 1), "float32") + t3 = relay.TensorType((5, 10, 4), "float32") + expected_ty = relay.FuncType([t1, t2], t3) + assert_has_type(func, expected_ty) + + +def test_custom_op_rel_infer(): + """Tests infer type for custom_op""" + + def custom_log1_rel(arg_types, attrs): + assert len(arg_types) == 1, "type relation arg number mismatch!" + if attrs: + assert isinstance(attrs, DictAttrs) + inputa_type = arg_types[0] + return relay.TensorType(inputa_type.shape, inputa_type.dtype) + + op_name = "custom_log1" + _op.register(op_name, r"code(cal log of a tensor.)code") + _op.get(op_name).set_num_inputs(1) + _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.") + _op.get(op_name).set_attrs_type_key("DictAttrs") + # call customized relation functions + _op.get(op_name).add_type_rel("custom_log1", custom_log1_rel) + _op.get(op_name).set_support_level(1) + _op.register_pattern(op_name, _op.OpPattern.ELEMWISE) + _op.register_stateful(op_name, False) + + def clog(x): + return relay.Call(_op.get(op_name), [x]) + + tp = relay.TensorType((10, 10), "float32") + x = relay.var("x", tp) + sb = relay.ScopeBuilder() + t1 = sb.let("t1", clog(x)) + t2 = sb.let("t2", relay.add(t1, x)) + sb.ret(t2) + f = relay.Function([x], sb.get()) + fchecked = infer_expr(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) + + +def test_custom_op_rel_infer_exception(): + """Tests infer type for custom_op""" + + def custom_log1_rel(arg_types, attrs): + assert len(arg_types) == 2, "type relation arg number mismatch!" + return None + + op_name = "custom_log2" + _op.register(op_name, r"code(cal log of a tensor.)code") + _op.get(op_name).set_num_inputs(1) + _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.") + _op.get(op_name).set_attrs_type_key("DictAttrs") + # call customized relation functions + _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel) + _op.get(op_name).set_support_level(1) + _op.register_pattern(op_name, _op.OpPattern.ELEMWISE) + _op.register_stateful(op_name, False) + + def clog(x): + return relay.Call(_op.get(op_name), [x]) + + tp = relay.TensorType((10, 10), "float32") + x = relay.var("x", tp) + sb = relay.ScopeBuilder() + t1 = sb.let("t1", clog(x)) + t2 = sb.let("t2", relay.add(t1, x)) + sb.ret(t2) + f = relay.Function([x], sb.get()) + with pytest.raises(tvm.error.TVMError) as cm: + fchecked = infer_expr(f) + assert "type relation arg number mismatch" in str(cm.execption) + + +def test_repeat_register(): + op_name = "custom_log3" + _op.register(op_name, r"code(cal log of a tensor.)code") + with pytest.raises(tvm.error.TVMError) as cm: + _op.register(op_name) + assert "Operator custom_log3 is registered before" in str(cm.execption) + + if __name__ == "__main__": import sys diff --git a/tests/python/relay/utils/ref_funcs.py b/tests/python/relay/utils/ref_funcs.py new file mode 100644 index 000000000000..924805b2295e --- /dev/null +++ b/tests/python/relay/utils/ref_funcs.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + + +def gather_nd(data_np, indices_np, batch_dims=0): + """gather_nd implemented using numpy""" + data_shape = data_np.shape + indices_shape = indices_np.shape + + def gather_nd_batch_dims_1_ref(data, indices): + res = [] + for i, row in enumerate(data): + indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch + res.append(row[indices_tuple]) + # stack on the batch dim + return np.stack(res, 0) + + if batch_dims > 1: + data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:]) + indices_np_reshape = np.reshape( + indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :] + ) + + ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape) + + out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:] + ref_res = np.reshape(ref_res, out_shape) + elif batch_dims == 1: + ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np) + else: + ref_res = data_np[tuple(indices_np)] + + return ref_res diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index ab2ba51aa7b1..8dbe94b45a2f 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -148,6 +148,8 @@ def check_target(target): if use_cudnn: check_target("cuda -model=unknown -libs=cudnn") + if ("opencl", tvm.device("opencl")) in tvm.testing.enabled_targets(): + check_target("opencl -device=intel_graphics") @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_group_conv2d.py b/tests/python/topi/python/test_topi_group_conv2d.py index e5a2fe7f28ab..55b24feece93 100644 --- a/tests/python/topi/python/test_topi_group_conv2d.py +++ b/tests/python/topi/python/test_topi_group_conv2d.py @@ -30,6 +30,22 @@ import tvm.testing +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel // bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]o[x]i + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn, kh, kw)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 1, 3)) + return kernel + + _group_conv2d_nchw_implement = { "generic": (topi.nn.group_conv2d_nchw, topi.generic.schedule_group_conv2d_nchw), "gpu": (topi.cuda.group_conv2d_nchw, topi.cuda.schedule_group_conv2d_nchw), @@ -154,6 +170,7 @@ def check_target(target): oc_block_factor = 4 +ic_block_factor = 4 def verify_group_conv2d_NCHWc_int8( @@ -176,6 +193,151 @@ def verify_group_conv2d_NCHWc_int8( in_height = in_width = in_size + A = te.placeholder( + (batch, in_channel // ic_block_factor, in_height, in_width, ic_block_factor), + name="A", + dtype="int8", + ) + W = te.placeholder( + ( + num_filter // oc_block_factor, + (in_channel // groups) // ic_block_factor, + kernel, + kernel, + oc_block_factor, + ic_block_factor, + ), + name="W", + dtype="int8", + ) + bias = te.placeholder( + (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype="int8" + ) + + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.randint( + low=-128, high=127, size=(batch, in_channel, in_height, in_width) + ).astype(dtype) + w_np = np.random.randint( + low=-128, high=128, size=(num_filter, in_channel // groups, kernel, kernel) + ).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype( + dtype + ) + + # convert to NCHWc + _, _, out_height, out_width = c_np.shape + c_np = c_np.reshape( + (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width) + ).transpose(0, 1, 3, 4, 2) + + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + + return ( + _transform_data(a_np, ic_block_factor), + _transform_kernel(w_np, ic_block_factor, oc_block_factor), + b_np, + c_np, + ) + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_target(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version): + print("Skip because int8 intrinsics are not available") + return + + print("Running on target: %s" % target) + with tvm.target.Target(target): + C = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding, dilation, groups, dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.cuda.schedule_group_conv2d_NCHWc_int8([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + if add_bias: + func = tvm.build( + s, + [A, W, bias, C], + target, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" + % ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + ), + ) + func(a, w, b, c) + else: + func = tvm.build( + s, + [A, W, C], + target, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" + % ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + ), + ) + func(a, w, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) + + for target in ["cuda"]: + check_target(target) + + +def verify_group_conv2d_nchw_int8( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + add_bias=False, + add_relu=False, +): + print( + "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" + % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups) + ) + + in_height = in_width = in_size + A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8") W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name="W", dtype="int8") bias = te.placeholder( @@ -187,7 +349,7 @@ def verify_group_conv2d_NCHWc_int8( bias_shape = get_const_tuple(bias.shape) dtype = A.dtype - @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8") + @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw_int8") def get_ref_data(): a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype) w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype) @@ -442,6 +604,30 @@ def test_group_conv2d_NCHWc_int8(): verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32) +@tvm.testing.requires_cuda +def test_group_conv2d_nchw_int8(): + with Int8Fallback(): + # ResNeXt-50 workload + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 256, 56, 256, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 256, 28, 256, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 512, 28, 512, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 512, 14, 512, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32) + + # bias, relu + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True) + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True) + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True, add_bias=True) + # dilation + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 2, 32) + + # batch size + verify_group_conv2d_nchw_int8(2, 128, 56, 128, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(9, 128, 56, 128, 3, 1, 1, 1, 32) + + def test_group_conv2d_nhwc(): # ResNeXt-50 workload verify_group_conv2d_nhwc(1, 128, 56, 128, 3, 1, 1, 1, 32) @@ -468,4 +654,5 @@ def test_group_conv2d_nhwc(): if __name__ == "__main__": test_group_conv2d_nchw() test_group_conv2d_NCHWc_int8() + test_group_conv2d_nchw_int8() test_group_conv2d_nhwc() diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py index 1be32e6ea1d1..b9ac51419772 100644 --- a/tests/python/topi/python/test_topi_prng.py +++ b/tests/python/topi/python/test_topi_prng.py @@ -112,6 +112,12 @@ def test_threefry_generate(target, dev): # check that gen out does not equal input assert (a != gen).any(), "Output generator should be different from input generator" + # check that we can generate data whose total number of elements is not a multiple of 4. + a, rands = threefry_generate(target, dev, gen, (7,)) + assert ( + rands.shape[0] == 7 and len(rands.shape) == 1 + ), "Output shape should match requested shape" + # test enough generates to go over generate limit gen = np.array( [0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64" diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 20172f07fd9e..ddde2e20e754 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -398,10 +398,12 @@ def check_device(target): check_device(target) -def verify_strided_slice(in_shape, begin, end, strides=None): +def verify_strided_slice(in_shape, begin, end, strides=None, axes=None): A = te.placeholder(shape=in_shape, name="A") strides = [1, 1, 1] if strides is None else strides - B = topi.strided_slice(A, begin, end, strides) + 1 + if axes: + strides = [strides[axis] for axis in axes] + B = topi.strided_slice(A, begin, end, strides, axes) + 1 def check_device(target): dev = tvm.device(target, 0) @@ -414,7 +416,7 @@ def check_device(target): foo = tvm.build(s, [A, B], target, name="stride_slice") x_np = np.random.uniform(size=in_shape).astype(A.dtype) - out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1 + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes=axes) + 1 data_nd = tvm.nd.array(x_np, dev) out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype) foo(data_nd, out_nd) @@ -819,6 +821,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None]) + verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1]) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 032b4db73918..3e26241cea94 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -30,15 +30,24 @@ def calc_numpy_unique(data, is_sorted=False): num_uniq = np.array([len(uniq)]).astype("int32") if not is_sorted: order = np.argsort(index) + index = np.sort(index) reverse_order = np.argsort(order) uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + return [ + uniq.astype(data.dtype), + index.astype("int32"), + inverse.astype("int32"), + counts, + num_uniq, + ] - def check_unique(data, is_sorted=False): + def check_unique(data, is_sorted=False, with_counts=False): # numpy reference - np_unique, np_indices, np_counts, np_num_unique = calc_numpy_unique(data, is_sorted) + np_unique, np_indices, np_inverse_indices, np_counts, np_num_unique = calc_numpy_unique( + data, is_sorted + ) num_unique = np_num_unique[0] implementations = { @@ -59,44 +68,54 @@ def check_unique(data, is_sorted=False): tvm_data = tvm.nd.array(data, device=dev) tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), device=dev) tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + tvm_inverse_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), device=dev) - # without counts with tvm.target.Target(target): te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input, False) + outs = fcompute(te_input, with_counts) s = fschedule(outs) func = tvm.build(s, [te_input, *outs]) - func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique) - assert tvm_num_unique.numpy()[0] == np_num_unique - np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_indices.numpy(), np_indices, atol=1e-5, rtol=1e-5) + if with_counts: + tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + func( + tvm_data, + tvm_unique, + tvm_indices, + tvm_inverse_indices, + tvm_num_unique, + tvm_counts, + ) + else: + func(tvm_data, tvm_unique, tvm_indices, tvm_inverse_indices, tvm_num_unique) - # with counts - tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) - with tvm.target.Target(target): - te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input, True) - s = fschedule(outs) - func = tvm.build(s, [te_input, *outs]) - func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique, tvm_counts) - - np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted) num_unique = np_num_unique[0] assert tvm_num_unique.numpy()[0] == np_num_unique + np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_indices.numpy(), np_indices, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + tvm_indices.numpy()[:num_unique], np_indices, atol=1e-5, rtol=1e-5 + ) + + np.testing.assert_allclose( + tvm_inverse_indices.numpy(), np_inverse_indices, atol=1e-5, rtol=1e-5 + ) + + if with_counts: + np.testing.assert_allclose( + tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5 + ) for in_dtype in ["int32", "int64"]: for is_sorted in [True, False]: - data = np.random.randint(0, 100, size=(1)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 10, size=(10)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) - check_unique(data, is_sorted) + for with_counts in [True, False]: + data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) + data = np.random.randint(0, 10, size=(10)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) + data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) if __name__ == "__main__": diff --git a/tests/python/unittest/test_autotvm_flop_calculator.py b/tests/python/unittest/test_autotvm_flop_calculator.py index e07cdac9cc9c..e28beaf98709 100644 --- a/tests/python/unittest/test_autotvm_flop_calculator.py +++ b/tests/python/unittest/test_autotvm_flop_calculator.py @@ -152,7 +152,7 @@ def test_average_pool(): def test_move(): - """No float number operation in simple move. So the estimator should raise an error """ + """No float number operation in simple move. So the estimator should raise an error""" N = 1024 A = te.placeholder((N,)) diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 9d4255c86b5e..781fd7f93886 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import random import tvm import tvm.testing import pickle @@ -77,7 +78,16 @@ def test_string(): assert s == z +def test_shape_tuple(): + shape = [random.randint(-10, 10) for _ in range(5)] + stuple = _container.ShapeTuple(shape) + len(stuple) == len(shape) + for a, b in zip(stuple, shape): + assert a == b + + if __name__ == "__main__": test_string() test_adt_constructor() test_tuple_object() + test_shape_tuple() diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index fc138bb43f1a..56ba9a085ffc 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -215,14 +215,21 @@ def check_cuda(n, value, lanes): y, x = s[A].op.axis s[A].vectorize(x) s[A].bind(y, bx) - fun = tvm.build(s, [A], "cuda", name="make_int4x8") + kernel_name = "make_int4x" + str(lanes) + fun = tvm.build(s, [A], "cuda", name=kernel_name) np_a = np.full((n, lanes), value, dtype="int8") a = tvm.nd.empty((n, lanes), dtype, dev) fun(a) np.testing.assert_equal(a.numpy(), np_a) + check_cuda(64, 1, 4) + check_cuda(64, 7, 4) check_cuda(64, 1, 8) check_cuda(64, 7, 8) + check_cuda(64, 1, 16) + check_cuda(64, 7, 16) + check_cuda(64, 1, 32) + check_cuda(64, 7, 32) @tvm.testing.requires_gpu diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 2770ae5878d0..dc165331729e 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -255,7 +255,7 @@ def test_vulkan_unique(): dtype = "int32" x = relay.var("x", shape=(relay.Any(),), dtype=dtype) mod = tvm.IRModule() - [unique, _, num_unique] = relay.unique(x, is_sorted=True) + [unique, _, _, num_unique] = relay.unique(x, is_sorted=True) mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique)) x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype) res_np = np.unique(x_np) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index b3ef8d5570b6..52dc4ccd9fef 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -18,6 +18,8 @@ import tvm from tvm.script import ty from tvm import te, tir +import numpy as np +import tvm.testing def test_unique_name(): @@ -281,6 +283,38 @@ def test_error_reporting(): assert False +def test_constant(): + M = 11 + A = te.placeholder((M,), name="A") + B = te.compute(tuple(), lambda: 2, name="B") + # Manually craft ProducerLoad because `B[]` is not allowed. + C = te.compute( + (M,), lambda x: A[x] + tvm.tir.expr.ProducerLoad(B, []), name="C", tag="broadcast" + ) + + func = te.create_prim_func([C, A]) + func = tvm.build(func) + a_np = np.random.uniform(size=(M,)).astype(A.dtype) + c = tvm.nd.array(np.zeros(M, dtype=C.dtype)) + x = func(c, tvm.nd.array(a_np)) + tvm.testing.assert_allclose(a_np + 2, c.numpy()) + + +def test_data_dependent_access(): + A = te.placeholder((10,), name="A") + B = te.placeholder((10,), name="B", dtype="int32") + C = te.compute((10,), lambda i: A[B[i]]) + + func = te.create_prim_func([C, A, B]) + func = tvm.build(func) + + a_np = np.random.uniform(size=(10,)).astype(A.dtype) + b_np = np.arange(10, dtype=B.dtype) + c = tvm.nd.array(np.zeros(10, dtype=C.dtype)) + func(c, tvm.nd.array(a_np), tvm.nd.array(b_np)) + tvm.testing.assert_allclose(a_np[b_np], c.numpy()) + + if __name__ == "__main__": test_unique_name() test_matmul() @@ -290,3 +324,4 @@ def test_error_reporting(): test_extern() test_arg_order() test_error_reporting() + test_constant() diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py new file mode 100644 index 000000000000..c34ec8d610d6 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_multi_producer_consumer(a: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + D = tir.match_buffer(d, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers + + +@tvm.script.tir +def elementwise_multi_consumer_inlined(a: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + D = tir.match_buffer(d, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] + + +@tvm.script.tir +def elementwise_standalone(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_standalone_dce(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + for j in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +@tvm.script.tir +def fail_multi_reader_writer(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.alloc_buffer((128, 128)) + D = tir.match_buffer(d, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + C[vi, vj] = A[vi, vj] + 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + D[vi, vj] = B[vi, vj] + C[vi, vj] + + +@tvm.script.tir +def elementwise_multi_reverse_loads(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 + + +@tvm.script.tir +def elementwise_multi_reverse_loads_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 + + +@tvm.script.tir +def opaque_access_load(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + tir.reads(B[0:128, 0:128]) + tir.writes(C[0:128, 0:128]) + C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0 + + +@tvm.script.tir +def opaque_access_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + tir.reads(B[0:128, 0:128]) + tir.writes(C[0:128, 0:128]) + tir.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = tir.load("float32", B.data, vi * 16 + vj) + 1.0 + + +@tvm.script.tir +def buffer_matched(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + Bb = tir.match_buffer_region(B[vi : vi + 1, vj]) + C[vi, vj] = Bb[0, 0] + 1.0 + + +@tvm.script.tir +def elementwise_predicate(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.where(B[i, j] < 10.0) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_predicate_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.where(A[i, j] * 2.0 < 10.0) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +@tvm.script.tir +def elementwise_multi_loads(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 126], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] + + +@tvm.script.tir +def elementwise_multi_loads_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 126], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_compute_inline_elementwise(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert sch.get(block_c).name_hint == "C" + + +def test_compute_inline_under_loop(): + sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert sch.get(block_c).name_hint == "C" + + +def test_compute_inline_as_dce(): + sch = tir.Schedule(elementwise_standalone, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"]) + assert sch.get(block_c).name_hint == "C" + + +def test_compute_inline_multi_consumer(): + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + block_d = sch.get_block("D") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_multi_consumer_inlined, sch.mod["main"]) + assert sch.get(block_c).name_hint == "C" + assert sch.get(block_d).name_hint == "D" + + +def test_compute_inline_fail_multi_writer(): + sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True, error_render_level="detail") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block_b) + + +def test_reverse_compute_inline_elementwise(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.reverse_compute_inline(block_c) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert sch.get(block_b).name_hint == "B" + + +def test_reverse_compute_inline_under_loop(): + sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.reverse_compute_inline(block_c) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert sch.get(block_b).name_hint == "B" + + +def test_reverse_compute_inline_fail_as_dce(): + sch = tir.Schedule(elementwise_standalone, debug_mode=True) + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(block_b) + + +def test_reverse_compute_inline_fail_multi_producer(): + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + block_d = sch.get_block("D") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(block_d) + + +def test_reverse_compute_inline_fail_multi_reader(): + sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True) + block_c = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(block_c) + + +def test_reverse_compute_multi_reverse_loads(): + sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mode=True) + block_c = sch.get_block("C") + sch.reverse_compute_inline(block_c) + tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) + + +def test_reverse_compute_fail_multi_reverse_loads(): + sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + block_c = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(block_c) + + +def test_opaque_access_load(): + sch = tir.Schedule(opaque_access_load, debug_mode=True) + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block_b) + + +def test_opaque_access_store(): + sch = tir.Schedule(opaque_access_store, debug_mode=True) + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block_b) + + +def test_buffer_matched(): + sch = tir.Schedule(buffer_matched, debug_mode=True) + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block_b) + + +def test_compute_inline_predicate(): + sch = tir.Schedule(elementwise_predicate, debug_mode=True) + block_b = sch.get_block("B") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) + + +def test_compute_inline_multi_loads(): + sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + block_b = sch.get_block("B") + sch.compute_inline(block_b) + tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) + + +if __name__ == "__main__": + test_compute_inline_elementwise() + test_compute_inline_under_loop() + test_compute_inline_as_dce() + test_compute_inline_multi_consumer() + test_compute_inline_fail_multi_writer() + test_reverse_compute_inline_elementwise() + test_reverse_compute_inline_under_loop() + test_reverse_compute_inline_fail_as_dce() + test_reverse_compute_inline_fail_multi_producer() + test_reverse_compute_inline_fail_multi_reader() + test_reverse_compute_multi_reverse_loads() + test_reverse_compute_fail_multi_reverse_loads() + test_opaque_access_load() + test_opaque_access_store() + test_buffer_matched() + test_compute_inline_predicate() + test_compute_inline_multi_loads() diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py new file mode 100644 index 000000000000..1fa658feabe3 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_tir_schedule_error_detail(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="detail") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "Cannot find a block with the name: wrong_name" in msg + + +def test_tir_schedule_error_fast(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="fast") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "Cannot find a block with the specified name" in msg + + +def test_tir_schedule_error_none(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="none") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "(not rendered)" in msg + + +if __name__ == "__main__": + test_tir_schedule_error_detail() + test_tir_schedule_error_fast() + test_tir_schedule_error_none() diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index c5163a8457af..1e3c8061e029 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -20,7 +20,7 @@ def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(sche, params, "main", None)["main"] func = passfunc()(tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt @@ -42,7 +42,7 @@ def get_promoted(op): lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body def test_promoted(op): @@ -111,7 +111,7 @@ def get_target(): ), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body tvm.ir.assert_structural_equal(get_eliminated(), get_target()) @@ -151,7 +151,7 @@ def check(fcompute_before, fcompute_after): b = te.placeholder((100,), dtype="uint16", name="B") c = te.compute((100,), fcompute_after(a, b), name="C") s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a, b): diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 7d02e4f12c1d..252a187dbdc5 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -522,7 +522,7 @@ def test_hoisting_block_scope_1(): s[B.op].bind(xi, te.thread_axis("threadIdx.y")) s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) - func = tvm.driver.build_module.form_irmodule(s, [A, B], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) @@ -622,7 +622,7 @@ def test_hoisting_block_scope_4(): s[C].pragma(xo2, "parallel_stride_pattern") s[C].pragma(xo2, "parallel_barrier_when_finish") s[C].vectorize(xi) - func = tvm.driver.build_module.form_irmodule(s, [A, B, C], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py new file mode 100644 index 000000000000..9d917466758b --- /dev/null +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import te +import numpy + + +@pytest.fixture +def mod_without_attrs(): + ib = tvm.tir.ir_builder.create() + A = tvm.tir.decl_buffer(name="A", shape=[1]) + stmt = ib.get() + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], stmt)) + + +@pytest.fixture +def mod(mod_without_attrs): + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( + mod_without_attrs + ) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + + return mod + + +def test_fails_if_not_global_symbol(mod_without_attrs): + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( + mod_without_attrs + ) + with pytest.raises(tvm.TVMError, match="Expect PrimFunc to have the global_symbol attribute"): + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + + +def test_fails_if_no_target(mod_without_attrs): + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod_without_attrs) + with pytest.raises(tvm.TVMError, match="Require the target attribute"): + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + + +@tvm.testing.parametrize_targets("c", "llvm", "cuda") +def test_device_setup(mod, target, dev): + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 1 + assert f.params[0].name == "arg0" + assert f.body.node == "default" + assert f.body.attr_key == "device_id" + assert f.body.value == 0 + assert f.body.body.node == "default" + assert f.body.body.attr_key == "device_type" + assert f.body.body.value == dev.device_type + + +def test_no_buffers_no_device_setup(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 1 + assert f.body.var.name == "A" + assert f.body.value.name == "arg0" + + +def test_argument_mapping(mod): + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 1 + assert f.params[0].name == "arg0" + assert f.body.body.body.var.name == "A" + assert f.body.body.body.value.name == "arg0" + + +def test_argument_mapping_multiple(): + ib = tvm.tir.ir_builder.create() + A = tvm.tir.decl_buffer(name="A", shape=[1]) + B = tvm.tir.decl_buffer(name="B", shape=[1]) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 2 + assert f.params[0].name == "arg0" + assert f.params[1].name == "arg1" + assert f.body.body.body.var.name == "A" + assert f.body.body.body.value.name == "arg0" + assert f.body.body.body.body.var.name == "B" + assert f.body.body.body.body.value.name == "arg1" + + +def test_argument_mapping_multiple_matching(): + ib = tvm.tir.ir_builder.create() + A = tvm.tir.decl_buffer(name="A", shape=[1]) + B = tvm.tir.decl_buffer(name="B", shape=[1]) + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, A], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 2 + assert f.params[0].name == "arg0" + assert f.params[1].name == "arg1" + assert f.body.body.body.var.name == "A" + assert f.body.body.body.value.name == "arg0" + assert f.body.body.body.body.condition.a.name == "A" + assert f.body.body.body.body.condition.b.name == "arg1" + + +def test_body(): + ib = tvm.tir.ir_builder.create() + A = tvm.tir.decl_buffer(name="A", shape=[1]) + B = tvm.tir.decl_buffer(name="B", shape=[1]) + C = ib.buffer_ptr(A) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B, C], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + assert len(f.params) == 3 + assert f.params[0].name == "arg0" + assert f.params[1].name == "arg1" + assert f.params[2].name == "arg2" + assert f.body.body.body.var.name == "A" + assert f.body.body.body.value.name == "arg2" + assert f.body.body.body.body.var.name == "B" + assert f.body.body.body.body.value.name == "arg1" + assert f.body.body.body.body.body.condition.a.name == "A" + assert f.body.body.body.body.body.condition.b.name == "arg0" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index bbfd899e2ef4..e471fa9c49b2 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.1.6 synr==0.3.0 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.0 synr==0.3.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib(). diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 1eb75be830c3..79b87b75fc70 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -40,7 +40,9 @@ rm -rf docs/doxygen # prepare auto scheduler tutorials rm -rf tutorials/auto_scheduler/*.json +rm -rf tutorials/get_started/*.json cp -f tutorials/auto_scheduler/ci_logs/*.json tutorials/auto_scheduler +cp -f tutorials/auto_scheduler/ci_logs/*.json tutorials/get_started # remove stale tutorials and always build from scratch. rm -rf docs/tutorials diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index a7a844b7a86e..5b0931405212 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -437,7 +437,7 @@ def tune_and_evaluate(): # in function :code:`run_tuning`. Say, # :code:`tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)` # 4. If you have multiple target CPUs, you can use all of them for measurements to -# parallelize the measurements. Check this :ref:`section ` +# parallelize the measurements. Check this :ref:`section ` # to learn how to use the RPC Tracker and RPC Server. # To use the RPC Tracker in auto-scheduler, replace the runner in :code:`TuningOptions` # with :any:`auto_scheduler.RPCRunner`. diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py index 4e80a74413aa..ef921563e466 100644 --- a/tutorials/autotvm/tune_conv2d_cuda.py +++ b/tutorials/autotvm/tune_conv2d_cuda.py @@ -77,7 +77,7 @@ # to tune other operators such as depthwise convolution and gemm. # In order to fully understand this template, you should be familiar with # the schedule primitives and auto tuning API. You can refer to the above -# tutorials and :doc:`autotvm tutorial ` +# tutorials and :ref:`autotvm tutorial ` # # It is worth noting that the search space for a conv2d operator # can be very large (at the level of 10^9 for some input shapes) diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index 68d263b6f29a..debf8b8ecf60 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -278,6 +278,10 @@ def tune_tasks( tuner_obj = XGBTuner(tsk, loss_type="rank") elif tuner == "xgb_knob": tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob") + elif tuner == "xgb_itervar": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar") + elif tuner == "xgb_curve": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve") elif tuner == "ga": tuner_obj = GATuner(tsk, pop_size=50) elif tuner == "random": @@ -291,7 +295,7 @@ def tune_tasks( if os.path.isfile(tmp_log_file): tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) - # do tuning + # process tuning tsk_trial = min(n_trial, len(tsk.config_space)) tuner_obj.tune( n_trial=tsk_trial, diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 0bd656dd81dd..8f631075429f 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -86,7 +86,7 @@ def find_width8(op): - """ Find all the 'tir.For' nodes whose extent can be divided by 8. """ + """Find all the 'tir.For' nodes whose extent can be divided by 8.""" if isinstance(op, tvm.tir.For): if isinstance(op.extent, tvm.tir.IntImm): if op.extent.value % 8 == 0: @@ -110,7 +110,7 @@ def find_width8(op): def vectorize8(op): - """ Split can vectorize the loops found in `find_width8`. """ + """Split can vectorize the loops found in `find_width8`.""" if op in loops: extent = op.extent.value name = op.loop_var.name diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 6a33d14e38c8..3804b1496d05 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -273,14 +273,16 @@ def visit_constant(self, c): # An example is below. -def print_ir(mod, info, is_before): +@tvm.instrument.pass_instrument +class PrintIR: """Print the name of the pass, the IR, only before passes execute.""" - if is_before: + + def run_before_pass(self, mod, info): print("Running pass: {}", info) print(mod) -with tvm.transform.PassContext(opt_level=3, trace=print_ir): +with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]): with tvm.target.Target("llvm"): # Perform the optimizations. mod = seq(mod) diff --git a/tutorials/get_started/tune_matmul_x86.py b/tutorials/get_started/auto_scheduler_matmul_x86.py similarity index 99% rename from tutorials/get_started/tune_matmul_x86.py rename to tutorials/get_started/auto_scheduler_matmul_x86.py index 8156d0e106ff..f9fb3615aedc 100644 --- a/tutorials/get_started/tune_matmul_x86.py +++ b/tutorials/get_started/auto_scheduler_matmul_x86.py @@ -23,7 +23,7 @@ In this tutorial, we will show how TVM's Auto Scheduling feature can find optimal schedules without the need for writing a custom template. -Different from the template-based :ref:`` which relies on +Different from the template-based :doc:`AutoTVM ` which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the computation declaration without any schedule commands or templates. The auto-scheduler can diff --git a/tutorials/get_started/autotvm_matmul.py b/tutorials/get_started/autotvm_matmul_x86.py similarity index 96% rename from tutorials/get_started/autotvm_matmul.py rename to tutorials/get_started/autotvm_matmul_x86.py index 234315b53ff9..97e1b0b8b55f 100644 --- a/tutorials/get_started/autotvm_matmul.py +++ b/tutorials/get_started/autotvm_matmul_x86.py @@ -15,17 +15,18 @@ # specific language governing permissions and limitations # under the License. """ -Optimizing Operators with Templates and AutoTVM -=============================================== +.. _tutorial-autotvm-matmul-x86: + +Optimizing Operators with Schedule Templates and AutoTVM +======================================================== **Authors**: `Lianmin Zheng `_, `Chris Hoge `_ -In this tutorial, we will now show how the TVM Template Extension (TE) language -can be used to write scheduling templates that can be searched by AutoTVM to -find optimal configurations of scheduling variables. This process is called -Auto-Tuning, and builds on TE to help automate the process of optimizing -operations. +In this tutorial, we show how the TVM Tensor Expression (TE) language +can be used to write schedule templates that can be searched by AutoTVM to +find the optimal schedule. This process is called Auto-Tuning, which helps +automate the process of optimizing tensor computation. This tutorial builds on the previous `tutorial on how to write a matrix multiplication using TE `. @@ -371,6 +372,6 @@ def matmul(N, L, M, dtype): # To gain a deeper understanding of how this works, we recommend expanding on # this example by adding new search parameters to the schedule based on # schedule operations demonstated in the `Getting Started With Tensor -# Expressions _` tutorial In the upcoming sections, we +# Expressions _` tutorial. In the upcoming sections, we # will demonstate the AutoScheduler, a method for TVM to optimize common # operators without the need for the user to provide a user-defined template. diff --git a/tutorials/get_started/auto_tuning_with_python.py b/tutorials/get_started/autotvm_relay_x86.py similarity index 98% rename from tutorials/get_started/auto_tuning_with_python.py rename to tutorials/get_started/autotvm_relay_x86.py index 8160442cdefd..67faec4505a6 100644 --- a/tutorials/get_started/auto_tuning_with_python.py +++ b/tutorials/get_started/autotvm_relay_x86.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """ -Compiling and Optimizing a Model with the Python AutoScheduler -============================================================== +Compiling and Optimizing a Model with the Python Interface (AutoTVM) +==================================================================== **Author**: `Chris Hoge `_ @@ -302,6 +302,7 @@ repeat=repeat, timeout=timeout, min_repeat_ms=min_repeat_ms, + enable_cpu_cache_flush=True, ) # Create a simple structure for holding tuning options. We use an XGBoost @@ -464,7 +465,7 @@ # Final Remarks # ------------- # -# In this tutorial, we we gave a short example of how to use the TVM Python API +# In this tutorial, we gave a short example of how to use the TVM Python API # to compile, run, and tune a model. We also discussed the need for pre and # post-processing of inputs and outputs. After the tuning process, we # demonstrated how to compare the performance of the unoptimized and optimize diff --git a/tutorials/get_started/install.py b/tutorials/get_started/install.py index 6d1db4ddb127..efc951a52709 100644 --- a/tutorials/get_started/install.py +++ b/tutorials/get_started/install.py @@ -23,8 +23,8 @@ Depending on your needs and your working environment, there are a few different methods for installing TVM. These include: - * Installing from source - * Installing from third-party binary package. +* Installing from source +* Installing from third-party binary package. """ ################################################################################ diff --git a/tutorials/get_started/introduction.py b/tutorials/get_started/introduction.py index 0ee79d334c03..0746c3983b61 100644 --- a/tutorials/get_started/introduction.py +++ b/tutorials/get_started/introduction.py @@ -19,7 +19,8 @@ ============ **Authors**: `Jocelyn Shiue `_, -`Chris Hoge `_ +`Chris Hoge `_, +`Lianmin Zheng `_ Apache TVM is an open source machine learning compiler framework for CPUs, GPUs, and machine learning accelerators. It aims to enable machine learning @@ -35,11 +36,11 @@ #. :doc:`Introduction ` #. :doc:`Installing TVM ` -#. :doc:`Compiling and Optimizing a Model with TVMC ` -#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler ` -#. :doc:`Working with Operators Using Tensor Expressions ` -#. :doc:`Optimizing Operators with Templates and AutoTVM ` -#. :doc:`Optimizing Operators with AutoScheduling ` +#. :doc:`Compiling and Optimizing a Model with the Command Line Interface ` +#. :doc:`Compiling and Optimizing a Model with the Python Interface ` +#. :doc:`Working with Operators Using Tensor Expression ` +#. :doc:`Optimizing Operators with Templates and AutoTVM ` +#. :doc:`Optimizing Operators with Template-free AutoScheduler ` #. :doc:`Cross Compilation and Remote Procedure Calls (RPC) ` #. :doc:`Compiling Deep Learning Models for GPUs ` """ @@ -51,18 +52,18 @@ # The diagram below illustrates the steps a machine model takes as it is # transformed with the TVM optimizing compiler framework. # -# .. image:: https://raw.githubusercontent.com/hogepodge/web-data/c339ebbbae41f3762873147c1e920a53a08963dd/images/getting_started/overview.png +# .. image:: https://raw.githubusercontent.com/apache/tvm-site/main/images/tutorial/overview.png # :width: 100% # :alt: A High Level View of TVM # # 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*. # The importer layer is where TVM can ingest models from other frameworks, like -# ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each +# Tensorflow, PyTorch, or ONNX. The level of support that TVM offers for each # frontend varies as we are constantly improving the open source project. If # you're having issues importing your model into TVM, you may want to try # converting it to ONNX. # -# 2. Translate to *Relay*, TVM's high level model language. +# 2. Translate to *Relay*, TVM's high-level model language. # A model that has been imported into TVM is represented in Relay. Relay is a # functional language and intermediate representation (IR) for neural networks. # It has support for: @@ -72,46 +73,47 @@ # differentiable language # - Ability to allow the user to mix the two programming styles # -# Relay applies several high-level optimization to the model, after which -# is runs the Relay Fusion Pass. To aid in the process of converting to -# Relay, TVM includes a Tensor Operator Inventory (TOPI) that has pre-defined -# templates of common computations. +# Relay applies graph-level optimization passes to optimize the model. # # 3. Lower to *Tensor Expression* (TE) representation. Lowering is when a # higher-level representation is transformed into a lower-level -# representation. In Relay Fusion Pass, the model is lowered from the -# higher-level Relay representation into a smaller set of subgraphs, where -# each node is a task. A task is a collection of computation templates, -# expressed in TE, where there parameters of the template can control how -# the computation is carried out on hardware. The specific ordering of compuation, -# defined by parameters to the TE template, is called a schedule. -# -# 4. Search for optimized schedule using *AutoTVM* or *AutoScheduler* for each -# task through tuning. Tuning is the process of searching the TE parameter -# space for a schedule that is optimized for target hardware. There are -# couple of optimization options available, each requiring varying levels of -# user interaction. The optimization options include: -# -# - **AutoTVM**: The user specifies a search template for the schedule of a TE task, -# or TE subraph. AutoTVM directs the search of the parameter space defined by the -# template to produce an optimized configuration. AutoTVM requires users to -# define manually templates for each operator as part of the TOPI. -# - **Ansor/AutoSchedule**: Using a TVM Operator Inventory (TOPI) of operations, -# Ansor can automatically search an optimization space with much less -# intervention and guidance from the end user. Ansor depends on TE templates to -# guide the search. -# -# 5. Choose the optimal configuration for the model. After tuning, an optimal schedule -# for each task is chosen. Regardless if it is AutoTVM or AutoSchedule, -# schedule records in JSON format are produced that are referred to by this step -# to build an optimized model. -# -# 6. Lower to a hardware specific compiler. After selecting an optimized configuration -# based on the tuning step, the model is then lowered to a representation -# expected by the target compiler for the hardware platform. This is the -# final code generation phase with the intention of producing an optimized -# model that can be deployed into production. TVM supports a number of -# different compiler backends including: +# representation. After applying the high-level optimizations, Relay +# runs FuseOps pass to partition the model into many small subgraphs and lowers +# the subgraphs to TE representation. Tensor Expression (TE) is a +# domain-specific language for describing tensor computations. +# TE also provides several *schedule* primitives to specify low-level loop +# optimizations, such as tiling, vectorization, parallelization, +# unrolling, and fusion. +# To aid in the process of converting Relay representation into TE representation, +# TVM includes a Tensor Operator Inventory (TOPI) that has pre-defined +# templates of common tensor operators (e.g., conv2d, transpose). +# +# 4. Search for the best schedule using the auto-tuning module *AutoTVM* or *AutoScheduler*. +# A schedule specifies the low-level loop optimizations for an operator or +# subgraph defined in TE. Auto-tuning modules search for the best schedule +# and compare them with cost models and on-device measurements. +# There are two auto-tuning modules in TVM. +# +# - **AutoTVM**: A template-based auto-tuning module. It runs search algorithms +# to find the best values for the tunable knobs in a user-defined template. +# For common operators, their templates are already provided in TOPI. +# - **AutoScheduler (a.k.a. Ansor)**: A template-free auto-tuning module. +# It does not require pre-defined schedule templates. Instead, it generates +# the search space automatically by analyzing the computation definition. +# It then searches for the best schedule in the generated search space. +# +# 5. Choose the optimal configurations for model compilation. After tuning, the +# auto-tuning module generates tuning records in JSON format. This step +# picks the best schedule for each subgraph. +# +# 6. Lower to Tensor Intermediate Representation (TIR), TVM's low-level +# intermediate representation. After selecting the optimal configurations +# based on the tuning step, each TE subgraph is lowered to TIR and be +# optimized by low-level optimization passes. Next, the optimized TIR is +# lowered to the target compiler of the hardware platform. +# This is the final code generation phase to produce an optimized model +# that can be deployed into production. TVM supports several different +# compiler backends including: # # - LLVM, which can target arbitrary microprocessor architecture including # standard x86 and ARM processors, AMDGPU and NVPTX code generation, and any diff --git a/tutorials/get_started/tensor_expr_get_started.py b/tutorials/get_started/tensor_expr_get_started.py index ee13d9e475f6..8fbdb751c9f8 100644 --- a/tutorials/get_started/tensor_expr_get_started.py +++ b/tutorials/get_started/tensor_expr_get_started.py @@ -17,22 +17,19 @@ """ .. _tutorial-tensor-expr-get-started: -Working with Operators Using Tensor Expressions -=============================================== +Working with Operators Using Tensor Expression +============================================== **Author**: `Tianqi Chen `_ In this tutorial we will turn our attention to how TVM works with Tensor -Expressions (TE) to create a space to search for performant configurations. TE +Expression (TE) to define tensor computations and apply loop optimizations. TE describes tensor computations in a pure functional language (that is each expression has no side effects). When viewed in context of the TVM as a whole, Relay describes a computation as a set of operators, and each of these operators can be represented as a TE expression where each TE expression takes -input tensors and produces an output tensor. It's important to note that the -tensor isn't necessarily a fully materialized array, rather it is a -representation of a computation. If you want to produce a computation from a -TE, you will need to use the scheduling features of TVM. +input tensors and produces an output tensor. -This is an introductory tutorial to the Tensor expression language in TVM. TVM +This is an introductory tutorial to the Tensor Expression language in TVM. TVM uses a domain specific tensor expression for efficient kernel construction. We will demonstrate the basic workflow with two examples of using the tensor expression language. The first example introduces TE and scheduling with vector @@ -47,8 +44,8 @@ # --------------------------------------------------------------- # # Let's look at an example in Python in which we will implement a TE for -# vector addition, followed by a schedule targeted towards a CPU. We begin by initializing a TVM -# environment. +# vector addition, followed by a schedule targeted towards a CPU. +# We begin by initializing a TVM environment. import tvm import tvm.testing @@ -59,7 +56,8 @@ # and specify it. If you're using llvm, you can get this information from the # command ``llc --version`` to get the CPU type, and you can check # ``/proc/cpuinfo`` for additional extensions that your processor might -# support. For example, ``tgt = "llvm -mcpu=`skylake` +# support. For example, you can use "llvm -mcpu=skylake-avx512" for CPUs with +# AVX-512 instructions. tgt = tvm.target.Target(target="llvm", host="llvm") @@ -69,7 +67,7 @@ # We describe a vector addition computation. TVM adopts tensor semantics, with # each intermediate result represented as a multi-dimensional array. The user # needs to describe the computation rule that generates the tensors. We first -# define a symbolic variable n to represent the shape. We then define two +# define a symbolic variable ``n`` to represent the shape. We then define two # placeholder Tensors, ``A`` and ``B``, with given shape ``(n,)``. We then # describe the result tensor ``C``, with a ``compute`` operation. The # ``compute`` defines a computation, with the output conforming to the @@ -79,7 +77,6 @@ # tensors. Remember, no actual computation happens during this phase, as we # are only declaring how the computation should be done. - n = te.var("n") A = te.placeholder((n,), name="A") B = te.placeholder((n,), name="B") @@ -88,10 +85,10 @@ ################################################################################ # .. note:: Lambda Functions # -# The second argument to the ``te.compute`` method is the function that -# performs the computation. In this example, we're using an anonymous function, -# also known as a ``lambda`` function, to define the computation, in this case -# addition on the ``i``th element of ``A`` and ``B``. +# The second argument to the ``te.compute`` method is the function that +# performs the computation. In this example, we're using an anonymous function, +# also known as a ``lambda`` function, to define the computation, in this case +# addition on the ``i``th element of ``A`` and ``B``. ################################################################################ # Create a Default Schedule for the Computation @@ -322,8 +319,6 @@ def evaluate_addition(func, target, optimization, log): bx, tx = s[C].split(C.op.axis[0], factor=64) - xXXXXXXXx - ################################################################################ # Finally we must bind the iteration axis bx and tx to threads in the GPU # compute grid. The naive schedule is not valid for GPUs, and these are diff --git a/tutorials/get_started/tvmc_command_line_driver.py b/tutorials/get_started/tvmc_command_line_driver.py index fffbfbf0356f..d9174da2ec58 100644 --- a/tutorials/get_started/tvmc_command_line_driver.py +++ b/tutorials/get_started/tvmc_command_line_driver.py @@ -494,5 +494,5 @@ # --help``. # # In the next tutorial, `Compiling and Optimizing a Model with the Python -# AutoScheduler `_, we will cover the same compilation +# Interface `_, we will cover the same compilation # and optimization steps using the Python interface. diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py index b2896306b7b2..5e517bf062ef 100644 --- a/tutorials/micro/micro_tflite.py +++ b/tutorials/micro/micro_tflite.py @@ -230,7 +230,7 @@ # from tvm.micro.contrib import zephyr # # repo_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding='utf-8').strip() -# project_dir = os.path.join(repo_root, "apps", "microtvm", "zephyr", "demo_runtime") +# project_dir = os.path.join(repo_root, "apps", "microtvm", "zephyr", "host_driven") # compiler = zephyr.ZephyrCompiler( # project_dir=project_dir, # board=BOARD, diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 5ec11677da70..a982b88b75e8 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -210,7 +210,7 @@ def __init__(self, start=-1, end=-1): super().__init__() def visit_call(self, call): - """ Visit the children. """ + """Visit the children.""" # First visit the children. args = [self.visit(arg) for arg in call.args] @@ -265,7 +265,7 @@ def __init__(self): super().__init__() def visit_call(self, call): - """ Visit the children. """ + """Visit the children.""" # First visit the children. args = [self.visit(arg) for arg in call.args] @@ -302,7 +302,7 @@ def __init__(self, bfactor, cfactor, weight_bits): super().__init__() def visit_call(self, call): - """ Visit the children. """ + """Visit the children.""" # First visit the children. oshape = _get_tensor_shape(call) odtype = _get_tensor_type(call) diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py index f243c3fc2c89..6b06d88096bf 100644 --- a/vta/python/vta/top/op.py +++ b/vta/python/vta/top/op.py @@ -41,7 +41,7 @@ # add clip vta strategy def compute_clip_vta(attrs, inputs, output_type): - """ Clip operator. """ + """Clip operator.""" x = inputs[0] a_min = attrs.a_min a_max = attrs.a_max diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 0b9cb719189f..5271b407fb8d 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -29,7 +29,7 @@ @autotvm.register_topi_compute("conv2d_packed.vta") def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): - """ Packed conv2d function.""" + """Packed conv2d function.""" if not is_packed_layout(layout): raise topi.InvalidShapeError() assert dilation == (1, 1) diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py index deb4ea779214..69d2579ad78c 100644 --- a/vta/python/vta/top/vta_group_conv2d.py +++ b/vta/python/vta/top/vta_group_conv2d.py @@ -28,7 +28,7 @@ @autotvm.register_topi_compute("group_conv2d_packed.vta") def group_conv2d_packed(cfg, data, kernel, strides, padding, dilation, group, out_dtype): - """ Packed group conv2d nchw function.""" + """Packed group conv2d nchw function.""" assert dilation == (1, 1) if padding[0]: diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index f8b4f2d2c5c3..7c7d02b40fbb 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -419,7 +419,7 @@ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): strides = list(x for x in strides) def raise_error(): - """Internal function to raise error """ + """Internal function to raise error""" raise RuntimeError( ( "Scope[%s]: cannot detect 2d pattern with elem_block=%d:" diff --git a/vta/tutorials/frontend/deploy_classification.py b/vta/tutorials/frontend/deploy_classification.py index 493db87d46d5..b2f909b9710a 100644 --- a/vta/tutorials/frontend/deploy_classification.py +++ b/vta/tutorials/frontend/deploy_classification.py @@ -141,7 +141,7 @@ ###################################################################### # Build the inference graph executor -# --------------------------------- +# ---------------------------------- # Grab vision model from Gluon model zoo and compile with Relay. # The compilation steps are: # diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 77ce6be66e63..3054bd0d7109 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -31,7 +31,6 @@ #define DMLC_USE_LOGGING_LIBRARY #include -#include #include #include #include