From a3688d9d34bf9d0802f96402f524a6ec58270cc6 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Wed, 9 Oct 2024 13:09:53 +0000 Subject: [PATCH 1/8] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/options, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration for default target. Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 4 ++++ tests/python/driver/tvmc/test_target_options.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index b5eee0482377..9e9a9ada5c1b 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -179,6 +179,10 @@ def validate_targets(parse_targets, additional_target_options=None): ) if additional_target_options is not None: + # Add-on target options are passed from codegen's config(BYOC) which has pass_default=True + # Eg: --target="llvm" + if len(tvm_targets) == 1: + return for target_name in additional_target_options: if not any([target for target in parse_targets if target["name"] == target_name]): first_option = list(additional_target_options[target_name].keys())[0] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index d98a8d588e22..0820a0f92f2e 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -86,6 +86,20 @@ def test_default_arg_for_mrvl_hybrid(): assert parsed.target_mrvl_num_tiles == 8 +@tvm.testing.requires_mrvl +def test_default_arg_for_mrvl_hybrid(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=mrvl, llvm", + ] + ) + assert parsed.target == "mrvl, llvm" + assert parsed.target_mrvl_mcpu == "cn10ka" + assert parsed.target_mrvl_num_tiles == 8 + + @tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser() From 50199bd8fdbb6ea06144ceda0fb68bb2110e6795 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Wed, 9 Oct 2024 13:09:53 +0000 Subject: [PATCH 2/8] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/options, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration for default target. Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 4 ++++ tests/python/driver/tvmc/test_target_options.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index b5eee0482377..9e9a9ada5c1b 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -179,6 +179,10 @@ def validate_targets(parse_targets, additional_target_options=None): ) if additional_target_options is not None: + # Add-on target options are passed from codegen's config(BYOC) which has pass_default=True + # Eg: --target="llvm" + if len(tvm_targets) == 1: + return for target_name in additional_target_options: if not any([target for target in parse_targets if target["name"] == target_name]): first_option = list(additional_target_options[target_name].keys())[0] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index d98a8d588e22..64218f02a0ab 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -86,6 +86,19 @@ def test_default_arg_for_mrvl_hybrid(): assert parsed.target_mrvl_num_tiles == 8 +@tvm.testing.requires_mrvl +# Test for default(LLVM) target, when built with USE_MRVL=ON +def test_mrvl_build_with_llvm_only_target(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=llvm", + ] + ) + assert parsed.target == "llvm" + + @tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser() From 4a86250f7e872f6a5e2dcc1e1692eb7038edd567 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Thu, 10 Oct 2024 17:14:22 +0000 Subject: [PATCH 3/8] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/optons, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 0bbc49ba543c..37a232df321f 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -183,7 +183,7 @@ def validate_targets(parse_targets, additional_target_options=None): if not any([target for target in parse_targets if target["name"] == target_name]): # When built with USE_MRVL=ON, add-on target options are passed from MRVL codegen's # config which has pass_default=True, Eg: --target="llvm" cnn.onnx - if len(tvm_targets) == 1 && (target_name == 'mrvl'): + if (len(tvm_targets) == 1) and (target_name == 'mrvl'): return first_option = list(additional_target_options[target_name].keys())[0] From aba2866619cf7f09a9d54ca2c6456d1b1ce1f490 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Fri, 11 Oct 2024 06:10:56 +0000 Subject: [PATCH 4/8] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/optons, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 37a232df321f..9506ca527ef5 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -182,8 +182,9 @@ def validate_targets(parse_targets, additional_target_options=None): for target_name in additional_target_options: if not any([target for target in parse_targets if target["name"] == target_name]): # When built with USE_MRVL=ON, add-on target options are passed from MRVL codegen's - # config which has pass_default=True, Eg: --target="llvm" cnn.onnx - if (len(tvm_targets) == 1) and (target_name == 'mrvl'): + # config which has pass_default=True and compiled with default target, don't error + # Use case: --target="llvm" cnn.onnx + if (len(tvm_targets) == 1) and (target_name == "mrvl"): return first_option = list(additional_target_options[target_name].keys())[0] From 23b9cb6d1cc3e115867806bb7ee8c49e2da18a6e Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Fri, 11 Oct 2024 09:17:04 +0000 Subject: [PATCH 5/8] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/optons, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 9506ca527ef5..a2d79a184804 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -183,7 +183,7 @@ def validate_targets(parse_targets, additional_target_options=None): if not any([target for target in parse_targets if target["name"] == target_name]): # When built with USE_MRVL=ON, add-on target options are passed from MRVL codegen's # config which has pass_default=True and compiled with default target, don't error - # Use case: --target="llvm" cnn.onnx + # Use case: --target="llvm" cnn.onnx if (len(tvm_targets) == 1) and (target_name == "mrvl"): return From fd36853ec24b1a06cb17335834fcf6d5c2403563 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Wed, 23 Oct 2024 12:53:07 +0000 Subject: [PATCH 6/8] [Frontend][ArgParse] Compile with default target(LLVM) when built USE_MRVL=ON(#17454) This is a use-case of invoking TVMC with default target though it is built with MRVL_ON. In command line processing, validate_target_args checks if there are add-on options derived from the default arguments of codegen/BYOC and it expects that particular codegen to be given explicitly in command line. However, certain codegen's can have default target alone, in that case codegen optios are not extracted there by relaxing the validation Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/composite_target.py | 8 ++++++++ python/tvm/driver/tvmc/target.py | 4 ++++ tests/python/driver/tvmc/test_target_options.py | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 6c51dd168963..e912ab564b55 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -52,41 +52,49 @@ "compute-library": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_arm_compute_lib, }, "cmsis-nn": { "config_key": "relay.ext.cmsisnn.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_cmsisnn, }, "ethos-n": { "config_key": "relay.ext.ethos-n.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_ethosn, }, "ethos-u": { "config_key": "relay.ext.ethos-u.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_ethosu, }, "bnns": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_bnns, }, "vitis-ai": { "config_key": "relay.ext.vitis_ai.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_vitis_ai, }, "clml": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_clml, }, "mrvl": { "config_key": "relay.ext.mrvl.options", "pass_default": True, + "default_target": "llvm", "pass_pipeline": partition_for_mrvl, }, } diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index b5eee0482377..121e387d5a0e 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -122,7 +122,11 @@ def _reconstruct_codegen_args(args, codegen_name): codegen = get_codegen_by_target(codegen_name) pass_configs = PassContext.list_configs() codegen_options = {} + default_tgt = codegen["default_target"] + # Do not fetch codegen options, if the default target alone is choosen by user + if codegen_name not in args.target and default_tgt is not None and default_tgt in args.target: + return codegen_options if codegen["config_key"] is not None and codegen["config_key"] in pass_configs: attrs = make_node(pass_configs[codegen["config_key"]]["type"]) fields = attrs_api.AttrsListFieldInfo(attrs) diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index d98a8d588e22..64218f02a0ab 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -86,6 +86,19 @@ def test_default_arg_for_mrvl_hybrid(): assert parsed.target_mrvl_num_tiles == 8 +@tvm.testing.requires_mrvl +# Test for default(LLVM) target, when built with USE_MRVL=ON +def test_mrvl_build_with_llvm_only_target(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=llvm", + ] + ) + assert parsed.target == "llvm" + + @tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser() From d143cedff4e5adb47acce599b2b36ea6af4aaa36 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Thu, 24 Oct 2024 09:41:53 +0000 Subject: [PATCH 7/8] [Frontend][ArgParse] Compile with default target(LLVM) when built USE_MRVL=ON(#17454) This is a use-case of invoking TVMC with default target though it is built with MRVL_ON. In command line processing, validate_target_args checks if there are add-on options derived from the default arguments of codegen/BYOC and it expects that particular codegen to be given explicitly in command line. However, certain codegen's can have default target alone, in that case codegen optios are not extracted there by relaxing the validation Signed-off-by: M N Ganesan --- conda/build-environment.yaml | 3 +- conda/recipe/meta.yaml | 2 +- docker/Dockerfile.ci_lint | 2 +- .../ubuntu2004_install_python_package.sh | 2 +- docker/install/ubuntu_install_jax.sh | 18 +- .../install/ubuntu_install_python_package.sh | 2 +- docker/install/ubuntu_install_tensorflow.sh | 4 +- .../ubuntu_install_tensorflow_aarch64.sh | 4 +- docker/install/ubuntu_install_tflite.sh | 40 +-- include/tvm/relax/attrs/manipulate.h | 23 ++ include/tvm/relax/transform.h | 21 ++ include/tvm/runtime/c_runtime_api.h | 2 +- include/tvm/tir/schedule/schedule.h | 11 + include/tvm/tir/stmt.h | 10 + .../main/java/org/apache/tvm/Function.java | 12 + .../src/main/java/org/apache/tvm/LibInfo.java | 2 + .../org/apache/tvm/contrib/GraphModule.java | 2 +- jvm/native/src/main/native/jni_helper_func.h | 21 ++ .../native/org_apache_tvm_native_c_api.cc | 15 + python/tvm/_ffi/libinfo.py | 2 +- python/tvm/driver/tvmc/composite_target.py | 8 + python/tvm/driver/tvmc/target.py | 5 + python/tvm/relax/frontend/nn/__init__.py | 2 + python/tvm/relax/frontend/nn/llm/kv_cache.py | 60 +--- .../tvm/relax/frontend/onnx/onnx_frontend.py | 139 +++++++- python/tvm/relax/op/__init__.py | 8 +- python/tvm/relax/op/binary.py | 26 ++ python/tvm/relax/op/create.py | 68 ++++ python/tvm/relax/op/manipulate.py | 83 +++++ python/tvm/relax/op/set.py | 37 ++ python/tvm/relax/pipeline.py | 50 ++- python/tvm/relax/transform/__init__.py | 2 + .../relax/transform/legalize_ops/binary.py | 3 +- .../relax/transform/legalize_ops/create.py | 30 ++ .../transform/legalize_ops/manipulate.py | 36 ++ python/tvm/relax/transform/transform.py | 29 ++ python/tvm/script/ir_builder/relax/ir.py | 12 + python/tvm/tir/schedule/schedule.py | 136 +++++++ python/tvm/topi/tensor.py | 35 +- src/arith/presburger_set.cc | 8 +- src/meta_schedule/database/database_utils.cc | 3 +- src/meta_schedule/postproc/rewrite_layout.cc | 8 +- src/relax/op/distributed/binary.cc | 2 + src/relax/op/tensor/binary.cc | 2 + src/relax/op/tensor/binary.h | 6 + src/relax/op/tensor/create.cc | 84 +++++ src/relax/op/tensor/create.h | 40 ++- src/relax/op/tensor/manipulate.cc | 209 +++++++++++ src/relax/op/tensor/manipulate.h | 45 +++ src/relax/op/tensor/set.cc | 23 ++ src/relax/op/tensor/set.h | 28 ++ .../attach_attr_layout_free_buffers.cc | 113 ++++++ .../transform/split_layout_rewrite_preproc.cc | 327 +++++++++++++++++ src/runtime/opencl/opencl_common.h | 2 + src/tir/schedule/analysis/analysis.cc | 4 +- src/tir/schedule/concrete_schedule.cc | 14 +- src/tir/schedule/concrete_schedule.h | 14 +- src/tir/schedule/primitive.h | 10 + .../primitive/annotate_buffer_access.cc | 167 +++++++++ src/tir/schedule/schedule.cc | 7 + src/tir/schedule/trace.cc | 4 +- src/tir/schedule/traced_schedule.cc | 20 +- src/tir/schedule/traced_schedule.h | 2 + src/tir/transforms/compact_buffer_region.cc | 43 ++- .../python/driver/tvmc/test_target_options.py | 13 + ...t_meta_schedule_postproc_rewrite_layout.py | 3 +- tests/python/relax/test_frontend_onnx.py | 121 ++++++- tests/python/relax/test_op_create.py | 58 +++ tests/python/relax/test_op_manipulate.py | 77 ++++ tests/python/relax/test_op_set.py | 34 ++ ...ansform_attach_attr_layout_free_buffers.py | 311 ++++++++++++++++ .../test_transform_legalize_ops_manipulate.py | 62 +++- ..._transform_split_layout_rewrite_preproc.py | 220 ++++++++++++ ...est_tir_schedule_annotate_buffer_access.py | 332 ++++++++++++++++++ .../test_tir_schedule_sampling.py | 28 ++ .../test_tir_schedule_split_fuse.py | 35 ++ .../test_topi_depthwise_conv2d_back_input.py | 4 +- version.py | 2 +- web/package-lock.json | 4 +- web/package.json | 2 +- 80 files changed, 3317 insertions(+), 141 deletions(-) create mode 100644 src/relax/transform/attach_attr_layout_free_buffers.cc create mode 100644 src/relax/transform/split_layout_rewrite_preproc.cc create mode 100644 src/tir/schedule/primitive/annotate_buffer_access.cc create mode 100644 tests/python/relax/test_transform_attach_attr_layout_free_buffers.py create mode 100644 tests/python/relax/test_transform_split_layout_rewrite_preproc.py create mode 100644 tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 8eb25ce01ac7..de4e6f4234d7 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -26,7 +26,8 @@ channels: # The packages to install to the environment dependencies: - python=3.9 - - conda-build + - conda < 24.9.0 + - conda-build < 24.9.0 - git - llvmdev >=11 - numpy diff --git a/conda/recipe/meta.yaml b/conda/recipe/meta.yaml index d4477468c79d..e340b25e5ba1 100644 --- a/conda/recipe/meta.yaml +++ b/conda/recipe/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = '0.18.dev0' %} +{% set version = '0.19.dev0' %} {% set pkg_name = 'tvm' %} {% set cuda_tag = cuda_version | replace('.', '') %} # [cuda] {% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda] diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index bab0cd0ebf9c..89749b75bca8 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -38,7 +38,7 @@ ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. RUN apt-get update && apt-install-and-clear -y doxygen graphviz curl shellcheck -RUN pip3 install cpplint pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3 +RUN pip3 install cpplint==1.6.1 pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3 # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh diff --git a/docker/install/ubuntu2004_install_python_package.sh b/docker/install/ubuntu2004_install_python_package.sh index f1c03cf1c0e2..c72ea5d4fa66 100644 --- a/docker/install/ubuntu2004_install_python_package.sh +++ b/docker/install/ubuntu2004_install_python_package.sh @@ -35,7 +35,7 @@ pip3 install --upgrade \ psutil \ pytest \ git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \ - pytest-profiling \ + pytest-profiling==1.7.0 \ pytest-xdist \ pytest-rerunfailures==10.2 \ requests \ diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh index 17114e0efce8..19149909161e 100644 --- a/docker/install/ubuntu_install_jax.sh +++ b/docker/install/ubuntu_install_jax.sh @@ -20,18 +20,16 @@ set -e set -u set -o pipefail -JAX_VERSION=0.4.30 - -# Install jaxlib +# Install jax and jaxlib if [ "$1" == "cuda" ]; then - pip install -U \ - "jax[cuda12]~=${JAX_VERSION}" \ - jaxlib~=${JAX_VERSION} + pip3 install --upgrade \ + jaxlib~=0.4.9 \ + "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else - pip3 install -U \ - jax~=${JAX_VERSION} \ - jaxlib~=${JAX_VERSION} + pip3 install --upgrade \ + jaxlib~=0.4.9 \ + "jax[cpu]~=0.4.9" fi # Install flax -pip3 install flax~=0.8.5 +pip3 install flax~=0.6.9 diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 593ba15f5947..7fe82a1db414 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -35,7 +35,7 @@ pip3 install --upgrade \ psutil \ pytest \ git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \ - pytest-profiling \ + pytest-profiling!=1.8.0 \ pytest-xdist \ pytest-rerunfailures==10.2 \ requests \ diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 012b678916b3..2225b7aef3b8 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -21,5 +21,5 @@ set -u set -o pipefail pip3 install \ - keras==3.5 \ - tensorflow==2.17.0 + keras==2.9 \ + tensorflow==2.9.1 diff --git a/docker/install/ubuntu_install_tensorflow_aarch64.sh b/docker/install/ubuntu_install_tensorflow_aarch64.sh index 4b158948387b..fcd912a4478a 100755 --- a/docker/install/ubuntu_install_tensorflow_aarch64.sh +++ b/docker/install/ubuntu_install_tensorflow_aarch64.sh @@ -25,5 +25,5 @@ apt-install-and-clear -y --no-install-recommends libhdf5-dev # h5py wheel tries to use the wrong .so file pip3 install \ numpy==1.23.5 \ - keras==3.5 \ - tensorflow-aarch64~=2.16.1 + keras==2.9 \ + tensorflow-aarch64~=2.9.3 diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 8faabc022640..36e6dfc42794 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -26,11 +26,11 @@ set -o pipefail TENSORFLOW_VERSION=$(python3 -c "import tensorflow; print(tensorflow.__version__)" 2> /dev/null) # Download, build and install flatbuffers -git clone --branch=v24.3.25 --depth=1 --recursive https://github.com/google/flatbuffers.git -pushd flatbuffers - cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" - ninja install -j8 -popd +git clone --branch=v1.12.0 --depth=1 --recursive https://github.com/google/flatbuffers.git +cd flatbuffers +cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" +make install -j8 +cd .. # Install flatbuffers python packages. pip3 install flatbuffers @@ -41,22 +41,22 @@ pip3 install flatbuffers git clone https://github.com/tensorflow/tensorflow --branch=v${TENSORFLOW_VERSION} --depth 1 mkdir -p /opt/tflite -pushd /opt/tflite - cmake -G Ninja \ - -DTFLITE_ENABLE_XNNPACK=OFF \ - /tensorflow/tensorflow/lite +cd /opt/tflite +cmake \ + -DTFLITE_ENABLE_XNNPACK=OFF \ + /tensorflow/tensorflow/lite + +cmake --build . +cd - - cmake --build . -popd # Setup tflite from schema mkdir tflite -find / -name "schema.fbs" -cp /tensorflow/tensorflow/lite/stablehlo/schema/schema.fbs tflite -pushd tflite - flatc --python schema.fbs +cp tensorflow/tensorflow/lite/schema/schema.fbs tflite +cd tflite +flatc --python schema.fbs - cat <setup.py +cat <setup.py import setuptools setuptools.setup( @@ -77,12 +77,12 @@ setuptools.setup( ) EOM - cat <__init__.py +cat <__init__.py name = "tflite" EOM - # Install tflite over python3 - python3 setup.py install +# Install tflite over python3 +python3 setup.py install -popd +cd .. rm -rf tflite diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index ef4265d73b4b..ea41488354d8 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -164,6 +164,29 @@ struct ScatterElementsAttrs : public tvm::AttrsNode { "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\"."); } }; // struct ScatterElementsAttrs + +/*! \brief Attributes used in scatter_nd operators */ +struct ScatterNDAttrs : public tvm::AttrsNode { + String reduction; + + TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") { + TVM_ATTR_FIELD(reduction).set_default("update").describe( + "Accumulation mode of the ScatterND, " + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); + } +}; // struct ScatterNDAttrs + +/*! \brief Attributes used in one_hot operator */ +struct OneHotAttrs : public tvm::AttrsNode { + int depth; + int axis; + + TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") { + TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill."); + } +}; // struct OneHotAttrs + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5a7b85ac1376..eaad44a93ace 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_war */ TVM_DLL Pass RealizeVDevice(); +/*! + * \brief Attach layout free buffers to the tir::PrimFunc. + * + * This pass is used to attach layout free buffers to the tir::PrimFunc according to + * the function usage in the relax function. Currently, the layout free buffers are the model + * weights and relax constants. + * + * \note We recommend applying CanonicalizeBindings before this pass. + * \return The Pass. + */ +TVM_DLL Pass AttachAttrLayoutFreeBuffers(); + +/*! + * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc. + * + * This pass is used in the prepack weight after meta_schedule tuning. + * + * \return The Pass. + */ +TVM_DLL Pass SplitLayoutRewritePreproc(); + /*! * \brief Lift transformation of the parameters of a function. * diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index d26c95e4f53c..438d049ed4a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -73,7 +73,7 @@ #endif // TVM version -#define TVM_VERSION "0.18.dev0" +#define TVM_VERSION "0.19.dev0" // TVM Runtime is DLPack compatible. #include diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 092bd52d5634..e4b13888f948 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object { */ virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; + /*! + * \brief Annotate the buffer access of a block + * \param block_rv The block to be annotated + * \param buffer_index The index of the buffer in block's read or write region + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \param index_map The index map that defines the new read or write region + */ + virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index c77254ed34cb..38289af463d5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution"; /*! \brief Mark that a block is disallowed in auto inline. */ constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule"; +/*! \brief Mark that a block has an explicitly specified read region. + * This is used to override the default read region inference in TIR. + */ +constexpr const char* explicit_read_region = "explicit_read_region"; + +/*! \brief Mark that a block has an explicitly specified write region. + * This is used to override the default write region inference in TIR. + */ +constexpr const char* explicit_write_region = "explicit_write_region"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index df535a87aa85..594b35b0af68 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -222,6 +222,16 @@ public Function pushArg(byte[] arg) { return this; } + /** + * Push argument to the function. + * @param arg Device. + * @return this + */ + public Function pushArg(Device arg) { + Base._LIB.tvmFuncPushArgDevice(arg); + return this; + } + /** * Invoke function with arguments. * @param args Can be Integer, Long, Float, Double, String, NDArray. @@ -255,6 +265,8 @@ private static void pushArgToStack(Object arg) { Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); + } else if (arg instanceof Device) { + Base._LIB.tvmFuncPushArgDevice((Device) arg); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index 62b8c901bd71..aede9be334c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -37,6 +37,8 @@ class LibInfo { native void tvmFuncPushArgHandle(long arg, int argType); + native void tvmFuncPushArgDevice(Device device); + native int tvmFuncListGlobalNames(List funcNames); native int tvmFuncFree(long handle); diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java index 737fdef24ae8..0a0bc7efc46d 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java @@ -41,7 +41,7 @@ public class GraphModule { private Function fdebugGetOutput; private Function floadParams; - GraphModule(Module module, Device dev) { + public GraphModule(Module module, Device dev) { this.module = module; this.device = dev; fsetInput = module.getFunction("set_input"); diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index d60a1a4230b7..3e44f757392d 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -214,4 +214,25 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { return NULL; } +// Helper function to pack two int32_t values into an int64_t +inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) { + int64_t result; + int32_t* parts = reinterpret_cast(&result); + + // Lambda function to check endianness + const auto isLittleEndian = []() -> bool { + uint32_t i = 1; + return *reinterpret_cast(&i) == 1; + }; + + if (isLittleEndian()) { + parts[0] = device_type; + parts[1] = device_id; + } else { + parts[1] = device_type; + parts[0] = device_id; + } + return result; +} + #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 09522381f181..c039508b4b7f 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -112,6 +112,21 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* e->tvmFuncArgTypes.push_back(static_cast(argType)); } +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj, + jobject arg) { + jclass deviceClass = env->FindClass("org/apache/tvm/Device"); + jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I"); + jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I"); + jint deviceType = env->GetIntField(arg, deviceTypeField); + jint deviceId = env->GetIntField(arg, deviceIdField); + + TVMValue value; + value.v_int64 = deviceToInt64(deviceType, deviceId); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); + e->tvmFuncArgValues.push_back(value); + e->tvmFuncArgTypes.push_back(kDLDevice); +} + JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 2ec4ba8e31be..f29ddaab72a9 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -247,4 +247,4 @@ def find_include_path(name=None, search_path=None, optional=False): # We use the version of the incoming release for code # that is under development. # The following line is set by tvm/python/update_version.py -__version__ = "0.18.dev0" +__version__ = "0.19.dev0" diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 6c51dd168963..e912ab564b55 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -52,41 +52,49 @@ "compute-library": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_arm_compute_lib, }, "cmsis-nn": { "config_key": "relay.ext.cmsisnn.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_cmsisnn, }, "ethos-n": { "config_key": "relay.ext.ethos-n.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_ethosn, }, "ethos-u": { "config_key": "relay.ext.ethos-u.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_ethosu, }, "bnns": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_bnns, }, "vitis-ai": { "config_key": "relay.ext.vitis_ai.options", "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_vitis_ai, }, "clml": { "config_key": None, "pass_default": False, + "default_target": None, "pass_pipeline": partition_for_clml, }, "mrvl": { "config_key": "relay.ext.mrvl.options", "pass_default": True, + "default_target": "llvm", "pass_pipeline": partition_for_mrvl, }, } diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index b5eee0482377..4cfaf130e4db 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -122,6 +122,11 @@ def _reconstruct_codegen_args(args, codegen_name): codegen = get_codegen_by_target(codegen_name) pass_configs = PassContext.list_configs() codegen_options = {} + default_tgt = codegen["default_target"] + + # Do not fetch codegen options, if the default target alone is choosen by user + if codegen_name not in args.target and default_tgt is not None and default_tgt in args.target: + return codegen_options if codegen["config_key"] is not None and codegen["config_key"] in pass_configs: attrs = make_node(pass_configs[codegen["config_key"]]["type"]) diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index a8200d8dd627..f490af7062b0 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -23,6 +23,8 @@ from .modules import ( GELU, Conv1D, + Conv2D, + Conv3D, ConvTranspose1D, Embedding, GroupNorm, diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index fd866ae06c16..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -925,12 +925,8 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - # Keeping lower thread limit for this kernel on adreno target - # to avoid register spill - THREAD_LIMIT = 256 + if target.kind.name == "opencl" and "android" in str(target.host): + THREAD_LIMIT = 256 if H_kv < 8 else 512 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -1574,11 +1570,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = ( - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - d, - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - ) + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1588,12 +1580,6 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], tile_z = 8 num_warps = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes - NUM_BLKS = group_size * 8 - # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches @@ -1722,6 +1708,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lz, ly in T.grid(tile_z, tile_y): with T.block("K_load"): i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( @@ -1836,14 +1824,6 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) - get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] - - def get_vecsize(extent): - return min(LOAD_VEC, (extent & ~(extent - 1))) - - def getxy_vecsize(x, y, t): - assert (x * y) % t == 0 - return min(get_vecsize(y), get_vecsize(x * y // t)) def get_tile_size(x, y, t): cnt = (x * y) // t @@ -1857,37 +1837,26 @@ def get_tile_size(x, y, t): def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] - x_extent, y_extent = get_extent(loop_x, loop_y) - vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) - yo, yv = sch.split(loop_y, [None, vec_size]) - yo_extent = y_extent // vec_size - tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) - xo, xi = sch.split(loop_x, [tile_x, None]) - yo, yi = sch.split(yo, [tile_y, None]) - sch.reorder(xi, yi, xo, yo) - t = sch.fuse(xi, yi) - ty, tx = sch.split(t, [num_warps, bdx]) + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.vectorize(yv) + sch.vectorize(vec) def apply_to_so_ewise(sch: tir.Schedule, block, tile): loop_x, loop_y = sch.get_loops(block)[-2:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) - sch.unroll(xi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) t = sch.fuse(xo, yo) ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) @@ -1903,12 +1872,6 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.reorder(ko, xi, yi, ki) else: sch.reorder(ko, ki, xi, yi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) - sch.unroll(xi) - sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block): @@ -1917,7 +1880,6 @@ def apply_to_md(sch, block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index aa156a025fef..6c9225070d3f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -260,7 +260,7 @@ def base_impl(cls, bb, inputs, attr, params): else inputs[0].data.numpy() ) y = ( - _np.array(inputs[0].value) + _np.array(inputs[1].value) if isinstance(inputs[1], relax.PrimValue) else inputs[1].data.numpy() ) @@ -287,7 +287,7 @@ class Sub(BinaryBase): relax_op = relax.op.subtract @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -298,7 +298,7 @@ class Mul(BinaryBase): relax_op = relax.op.multiply @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -309,7 +309,7 @@ class Div(BinaryBase): relax_op = relax.op.divide @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -320,7 +320,24 @@ class Pow(BinaryBase): relax_op = relax.op.power @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + + +class Mod(BinaryBase): + """Converts an onnx Mod node into an equivalent Relax expression.""" + + numpy_op = _np.mod + relax_op = relax.op.mod + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + if attr.get("fmod", 0) == 0: + cls.numpy_op = _np.fmod + cls.relax_op = relax.op.floor_mod + else: + cls.numpy_op = _np.mod + cls.relax_op = relax.op.mod return cls.base_impl(bb, inputs, attr, params) @@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.log_softmax(inputs[0], axis=axis) +class Hardmax(OnnxOpConverter): + """Converts an onnx Hardmax node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + axis = attr.get("axis", -1) + indices = inputs[0] + dtype = indices.struct_info.dtype + axis_len = int(inputs[0].struct_info.shape[axis]) + argmax = relax.op.argmax(indices, axis=axis) + on_value = relax.PrimValue(tvm.tir.const(1.0, dtype)) + off_value = relax.PrimValue(tvm.tir.const(0.0, dtype)) + + one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis) + return one_hot + + class Transpose(OnnxOpConverter): """Converts an onnx Transpose node into an equivalent Relax expression.""" @@ -692,6 +726,36 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) +class ScatterND(OnnxOpConverter): + """Convert an onnx ScatterND node into an equivalent Relax expression.""" + + @staticmethod + def _reduction_check(attr, valid_reductions: List[str]): + reduction = attr.get("reduction", None) + reduction = reduction or b"update" + reduction = reduction.decode("utf-8") + reduction = "update" if reduction == "none" else reduction + assert ( + reduction in valid_reductions + ), f"Only {valid_reductions} reductions are supported, but {reduction} is gotten" + + return reduction + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2]) + + @classmethod + def _impl_v16(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + class Size(OnnxOpConverter): """Convert an onnx Size node into an equivalent Relax expression.""" @@ -701,6 +765,20 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0]))) +class EyeLike(OnnxOpConverter): + """Convert an onnx EyeLike node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + k = attr.get("k", 0) + input_dtype = inputs[0].struct_info.dtype + if "dtype" in attr and get_type(attr["dtype"]) != input_dtype: + raise ValueError( + f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})" + ) + return relax.op.eye_like(inputs[0], k, input_dtype) + + class Gemm(OnnxOpConverter): """Convert an onnx Gemm node into an equivalent Relax expression.""" @@ -1552,6 +1630,35 @@ def _impl_v13(cls, bb, inputs, attr, params): class Pad(OnnxOpConverter): """Converts an onnx Pad node into an equivalent Relax expression.""" + @classmethod + def _impl_v2(cls, bb, inputs, attr, params): + pads = attr.get("pads") + pads = relax.const(_np.array(pads), inputs[0].struct_info.shape[0].dtype) + constant_value = attr.get("value") + if constant_value is None: + constant_value = 0.0 + + if isinstance(pads, relax.Constant): + pad_before, pad_after = _np.split(pads.data.numpy(), 2) + pad_before = _np.ndarray.tolist(pad_before) + pad_after = _np.ndarray.tolist(pad_after) + else: + raise ValueError("Dynamic pads are not supported yet.") + + pad_mode = attr.get("mode", b"constant").decode("utf-8") + if not pad_mode in ["constant", "edge", "reflect"]: + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) + + if pad_mode == "constant": + return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after, constant_value) + elif pad_mode == "reflect": + return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") + else: + # TODO(gigiblender) Support edge mode. + raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) + @classmethod def _impl_v11(cls, bb, inputs, attr, params): pads = get_constant(inputs[1], params) @@ -2461,13 +2568,13 @@ def _impl_v11(cls, bb, inputs, attr, params): depth = get_constant(inputs[1], params) values = get_constant(inputs[2], params) axis = attr.get("axis", -1) - dtype = values.struct_info.dtype assert isinstance(depth, relax.Constant), "Only constant depth currently supported." depth = depth.data.numpy().tolist() assert isinstance(values, relax.Constant), "Only constant values currently supported." values = values.data.numpy().tolist() off_value, on_value = values - return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) + off_value, on_value = relax.PrimValue(off_value), relax.PrimValue(on_value) + return relax.op.one_hot(indices, on_value, off_value, depth, axis) class Unique(OnnxOpConverter): @@ -2482,6 +2589,14 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.unique(data, sorted=sorted, axis=axis) +class NonZero(OnnxOpConverter): + """Converts an onnx NonZero node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.nonzero(inputs[0]) + + class HardSigmoid(OnnxOpConverter): """Converts an onnx HardSigmoid node into an equivalent Relax expression.""" @@ -2733,7 +2848,7 @@ def _get_convert_map(): "Sub": Sub, "Mul": Mul, "Div": Div, - # "Mod": Mod, + "Mod": Mod, "Less": Less, "LessOrEqual": LessOrEqual, "Greater": Greater, @@ -2803,7 +2918,7 @@ def _get_convert_map(): "Sigmoid": Sigmoid, "Softmax": Softmax, "LogSoftmax": LogSoftmax, - # "Hardmax": Hardmax, + "Hardmax": Hardmax, "Transpose": Transpose, "Unsqueeze": Unsqueeze, "Where": Where, @@ -2819,10 +2934,10 @@ def _get_convert_map(): # "GatherND": GatherND, "Scatter": Scatter, "ScatterElements": ScatterElements, - # "ScatterND": ScatterND, + "ScatterND": ScatterND, # "Compress": Compress, "Size": Size, - # "EyeLike": EyeLike, + "EyeLike": EyeLike, # Normalization "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, @@ -2867,7 +2982,7 @@ def _get_convert_map(): "Range": Range, "OneHot": OneHot, "Unique": Unique, - # "NonZero": NonZero, + "NonZero": NonZero, # "If": If, # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c99201e969b5..1603ea2f0f7e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -50,6 +50,7 @@ divide, equal, floor_divide, + floor_mod, greater, greater_equal, left_shift, @@ -60,6 +61,7 @@ logical_xor, maximum, minimum, + mod, multiply, not_equal, power, @@ -72,6 +74,8 @@ full_like, ones, ones_like, + eye, + eye_like, tril, triu, zeros, @@ -89,10 +93,12 @@ flatten, flip, layout_transform, + one_hot, permute_dims, repeat, reshape, scatter_elements, + scatter_nd, split, squeeze, tile, @@ -101,7 +107,7 @@ from .qdq import dequantize, quantize from .sampling import multinomial_from_uniform from .search import argmax, argmin, where -from .set import unique +from .set import nonzero, unique from .sorting import argsort, sort, topk from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 7632235cb32c..7a41c8b0953c 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr: return _ffi_api.subtract(x1, x2) # type: ignore +def mod(x1: Expr, x2: Expr) -> Expr: + """Modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + """ + return _ffi_api.mod(x1, x2) # type: ignore + + +def floor_mod(x1: Expr, x2: Expr) -> Expr: + """Floor modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + """ + return _ffi_api.floor_mod(x1, x2) # type: ignore + + ###################### Comparison operators ###################### diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 092d79a74dc4..c61d9521a41d 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: return _ffi_api.zeros_like(x, dtype) # type: ignore +def eye( + n: Union[PrimExprLike, PrimValue], + m: Optional[Union[PrimExprLike, PrimValue]] = None, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Union[str, DataType] = "float32", +) -> Expr: + """Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + n : Union[PrimExprLike, PrimValue] + Number of rows in the output. + + m : Optional[Union[PrimExprLike, PrimValue]] + Number of columns in the output. If None, defaults to n. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + m = n if m is None else m + n = n if isinstance(n, PrimValue) else PrimValue(n) + m = m if isinstance(m, PrimValue) else PrimValue(m) + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye(n, m, k, dtype) # type: ignore + + +def eye_like( + x: Expr, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Return a 2-D tensor with ones on the diagonal and zeros elsewhere, + with the same shape as the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye_like(x, k, dtype) # type: ignore + + def arange( start: Union[PrimExprLike, PrimValue], end: Optional[Union[PrimExprLike, PrimValue]] = None, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index da0a09cc7b51..3210cc821689 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -511,3 +511,86 @@ def scatter_elements( """ return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore + + +def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr: + """Scatter updates into an array according to indices. + + Parameters + ---------- + data: relax.Expr + The input data to be updated. + + indices: relax.Expr + The index positions to update in `data`. + + updates: relax.Expr + Values to replace to. + + reduction: str + Type of reduction to apply: update, add, mul, max, min. + It is "update" by default. + + Returns + ------- + result : relax.Expr + The result has the same shape as data. + + Examples + -------- + .. code-block:: python + + # inputs + data = [1, 2, 3, 4, 5, 6, 7, 8] + indices = [[4], [3], [1], [7]] + updates = [9, 10, 11, 12] + + # output + output = [1, 11, 3, 10, 9, 6, 7, 12] + + """ + return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore + + +def one_hot( + indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1 +) -> Expr: + """Returns a one-hot tensor. + + Parameters + ---------- + indices : relax.Expr + The indices to set to `on_value`. + + on_value : relax.PrimValue + The value to fill at `indices`. + + off_value : relax.PrimValue + The value to fill at other locations. + + depth : int + The depth of the one-hot dimension. + + axis : int, optional + The axis to fill. Default is -1 which adds a new dimension at the end. + + Returns + ------- + result : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + indices = [0, 1, 2] + depth = 3 + on_value = 1 + off_value = 0 + + one_hot(indices, on_value, off_value, depth) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] + """ + return _ffi_api.one_hot(indices, on_value, off_value, depth, axis) # type: ignore diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 0b86e19ce53f..c5db852ddd5d 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -110,3 +110,40 @@ def numpy_unique( return tvm.nd.array(output_sorted_numpy) output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) + + +def nonzero(x: Expr) -> Expr: + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + x : relax.Expr + The input data tensor. + + Returns + ------- + result : relax.Expr + A (n+1)-D tensor containing indices of non-zero elements. + + Note + ---- + This function is equivalent to `onnx.nonzero`. + + Examples + -------- + + .. code-block:: python + + x = [[0, 1], + [2, 0]] + nonzero(x) = [[0, 1], + [1, 0]] + + """ + return _ffi_api.nonzero(x) # type: ignore + + +@tvm.register_func("relax.run.nonzero") +def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array: + np_result = np.atleast_1d(x.numpy()).nonzero() + return tvm.nd.array(np.stack(np_result, axis=0)) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 582f5111aaf5..fe3dbc99fc15 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -109,6 +109,7 @@ def static_shape_tuning_pipeline( total_trials: int, target: Union[str, tvm.target.Target], work_dir: str = "tuning_logs", + cpu_weight_prepack: bool = False, ): """Tune the static shape model and store the log to database. @@ -122,18 +123,65 @@ def static_shape_tuning_pipeline( work_dir : str The directory to store the tuning logs. + + cpu_weight_prepack : bool + Whether to enable the cpu weight prepack feature. + + Note + ---- + `cpu_weight_prepack` is expected to be `True` when running on CPU for + better performance. However, it requires an explicit layout transformation + step by calling the corresponding vm function, which changes the interface + of deployment. So we disable it by default. Here is an example to enable it: + + .. code-block:: python + + mod = relax.pipeline.static_shape_tuning_pipeline( + total_trials=1000, + target="llvm -num-cores 16", + work_dir="tuning_logs", + cpu_weight_prepack=True, + )(mod) + + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device=tvm.cpu()) + + # Transform the params using the vm function + # the name should be f"{func_name}_transform_params" + params = vm["main_transform_params"](params["main"]) + + input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32")) + out = vm["main"](input_data, *params).numpy() """ @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + if cpu_weight_prepack: + pre_tuning_layout_rewrite = [transform.AttachAttrLayoutFreeBuffers()] + post_tuning_layout_rewrite = [ + transform.SplitLayoutRewritePreproc(), + transform.LiftTransformParams(), + transform.FoldConstant(), + ] + else: + pre_tuning_layout_rewrite = [] + post_tuning_layout_rewrite = [] + with tvm.target.Target(target): mod = tvm.transform.Sequential( [ transform.DecomposeOpsForInference(), transform.CanonicalizeBindings(), zero_pipeline(), - transform.MetaScheduleTuneIRMod({}, work_dir, total_trials), + *pre_tuning_layout_rewrite, + # Skip tuning if total_trials is 0 + ( + transform.MetaScheduleTuneIRMod({}, work_dir, total_trials) + if total_trials > 0 + else tvm.transform.Sequential([]) + ), transform.MetaScheduleApplyDatabase(work_dir), + *post_tuning_layout_rewrite, ] )(mod) diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 1ce864651cd9..16e4800ca33d 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -21,6 +21,7 @@ AllocateWorkspace, AlterOpImpl, AnnotateTIROpPattern, + AttachAttrLayoutFreeBuffers, AttachGlobalSymbol, BindParams, BindSymbolicVars, @@ -73,6 +74,7 @@ RewriteDataflowReshape, RunCodegen, SplitCallTIRByPattern, + SplitLayoutRewritePreproc, StaticPlanBlockMemory, ToMixedPrecision, ToNonDataflow, diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index d28e100edb9f..41e317f1e0ef 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -48,7 +48,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.power", _binary(topi.power)) register_legalize("relax.subtract", _binary(topi.subtract)) register_legalize("relax.equal", _binary(topi.equal)) - +register_legalize("relax.mod", _binary(topi.mod)) +register_legalize("relax.floor_mod", _binary(topi.floor_mod)) register_legalize("relax.greater", _binary(topi.greater)) register_legalize("relax.greater_equal", _binary(topi.greater_equal)) register_legalize("relax.less", _binary(topi.less)) diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 1b022672d0bd..8bf85e34dee8 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -70,6 +70,36 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu")) +def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc: + def eye_call_te(bb: BlockBuilder, call: Call) -> Expr: + _convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x, python_native=True) + if is_like: + x = call.args[0] + k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0 + n, m = x.struct_info.shape + dtype = x.struct_info.dtype + else: + n = _convert_to_scalar_const(call.args[0]) + m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n + k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2 else 0 + dtype = call.attrs.dtype + + return bb.call_te( + topi.eye, + n, + m, + k, + dtype, + primfunc_name_hint=primfunc_name, + ) + + return eye_call_te + + +register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye")) +register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like")) + + @register_legalize("relax.arange") def _arange(bb: BlockBuilder, call: Call) -> Expr: assert len(call.args) == 3 diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 1efa78c069ad..163085a07c34 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -168,6 +168,42 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.scatter_nd") +def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr: + # TODO(relax-team): Support native scatter_nd without te extern + def scatter_nd(data, indices, updates, reduction): + axes = list(range(len(indices.shape))) + indices = topi.transpose(indices, axes[-1:] + axes[:-1]) + return topi.scatter_nd(data, indices, updates, reduction) + + return bb.call_te( + scatter_nd, + call.args[0], + call.args[1], + call.args[2], + call.attrs.reduction, + ) + + +@register_legalize("relax.one_hot") +def _one_hot(bb: BlockBuilder, call: Call) -> Expr: + indices, on_value, off_value = call.args + if not (isinstance(on_value, relax.PrimValue) and isinstance(off_value, relax.PrimValue)): + raise ValueError("on_value and off_value must be PrimValue") + on_value, off_value = on_value.value, off_value.value + if on_value.dtype != off_value.dtype: + raise ValueError("on_value and off_value must have the same dtype") + return bb.call_te( + topi.one_hot, + indices, + on_value, + off_value, + call.attrs.depth, + call.attrs.axis, + on_value.dtype, + ) + + @register_legalize("relax.layout_transform") def _layout_transform(bb: BlockBuilder, call: Call) -> Expr: def te_layout_transform(data, name): diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 3330d4098734..603211b59ebc 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass: return _ffi_api.MergeCompositeFunctions() # type: ignore +def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass: + """Attach layout free buffers to the tir::PrimFunc. + + This pass is used to attach layout free buffers to the tir::PrimFunc according to + the function usage in the relax function. Currently, the layout free buffers are the model + weights and relax constants. + + Note that we recommend applying CanonicalizeBindings before this pass. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for attaching layout free buffers. + """ + return _ffi_api.AttachAttrLayoutFreeBuffers() # type: ignore + + +def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass: + """Split the TIR layout rewrite into multiple TIR functions. + This pass is used in the prepack weight after meta_schedule tuning. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for splitting TIR layout rewrite. + """ + return _ffi_api.SplitLayoutRewritePreproc() # type: ignore + + def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm.ir.transform.Pass: """Lift transformation of the parameters of a function. diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e6ff35ebe56b..049345fcb10d 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -85,10 +85,13 @@ ewise_fma, exp, expand_dims, + eye, + eye_like, flatten, flip, floor, floor_divide, + floor_mod, full, full_like, grad, @@ -119,6 +122,7 @@ memory, min, minimum, + mod, multinomial_from_uniform, multiply, negative, @@ -127,6 +131,7 @@ null_value, ones, ones_like, + one_hot, permute_dims, power, print, @@ -138,6 +143,7 @@ round, rsqrt, scatter_elements, + scatter_nd, shape_of, shape_to_tensor, sigmoid, @@ -738,6 +744,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "cumsum", "einsum", "scatter_elements", + "scatter_nd", "dataflow", "device", "divide", @@ -751,10 +758,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "exp", "expand_dims", "ext_dev", + "eye", + "eye_like", "flatten", "flip", "floor", "floor_divide", + "floor_mod", "full", "full_like", "func_attr", @@ -793,6 +803,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "metal", "min", "minimum", + "mod", "multinomial_from_uniform", "multiply", "negative", @@ -800,6 +811,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "null_value", "ones", "ones_like", + "one_hot", "opencl", "output", "permute_dims", diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index be88e234634f..17c256be3538 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -3907,3 +3907,139 @@ def unsafe_hide_buffer_access( buf_type, buf_index_array, ) + + @type_checked + def annotate_buffer_access( + self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable + ) -> None: + """Annotate the read or write region of a block + + Parameters + ---------- + block : BlockRV + The block to be annotated + buffer_index : int + The index of the buffer in block's read or write region + buf_type : str + The buffer type: "read" or "write" + gen_new_ranges : Callable + A function that takes the block's iter_vars and returns a + Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...] + which defines the new read or write region for the buffer. + Each element in the tuple can be: + - A single PrimExpr representing the iter_var itself + - A tuple of two PrimExprs representing the range (begin, end) + + Examples + -------- + Annotate a 2D read region for a buffer. + Before annotate_buffer_access, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_annotate_buffer_access( + A: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do annotate_buffer_access: + + .. code-block:: python + + sch = tir.Schedule(before_annotate_buffer_access) + block = sch.get_block("B") + sch.annotate_buffer_access(block, 0, "read", + lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1))) + print(sch.mod["main"].script()) + + After applying annotate_buffer_access, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_annotate_buffer_access( + A: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": 0}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + This annotates the read region for buffer A (index 0) in block "B" to be + [vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain. + + Note + ---- + This function allows manual specification of read or write regions, which + can be useful in cases where the compiler cannot accurately infer the + access pattern, such as complex data-dependent accesses. + It overrides the automatically inferred region for the specified buffer. + The function adds an annotation to the block, indicating that an explicit + region has been provided for the buffer at the given index. This annotation + is used in the CompactBufferAllocation pass to respect the manually specified + region instead of relying on automatic inference. + + Caution should be exercised when using this function, as incorrect annotations + may lead to incorrect code generation or runtime errors. It's crucial to + ensure that the specified region covers all actual reads or writes performed + by the block for the given buffer. + + """ + block_obj = self.get(block) + iter_vars = [x.var for x in block_obj.iter_vars] + new_ranges_spec = gen_new_ranges(*iter_vars) + if len(iter_vars) != len(new_ranges_spec): + raise ValueError( + f"Number of iter_vars ({len(iter_vars)}) must match " + f"number of new_ranges_spec ({len(new_ranges_spec)})" + ) + + result = [] + for rng in new_ranges_spec: + if isinstance(rng, (tuple, list)): + if len(rng) != 2: + raise ValueError( + "Tuple must have exactly 2 elements to represent (begin, end)." + ) + result.extend(rng) + elif isinstance(rng, PrimExpr): + result.extend([rng, rng + 1]) # Single point represented as (rng, rng + 1) + else: + raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}") + + # Create index_map using IndexMap constructor + index_map = IndexMap( + initial_indices=iter_vars, + final_indices=result, + inverse_index_map=None, + ) + + if buf_type == "read": + buffer_index_type = 0 + elif buf_type == "write": + buffer_index_type = 1 + else: + raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.") + + return _ffi_api.ScheduleAnnotateBufferAccess( + self, block, buffer_index, buffer_index_type, index_map + ) diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py index 31ebe86760cb..449c599deaf3 100644 --- a/python/tvm/topi/tensor.py +++ b/python/tvm/topi/tensor.py @@ -16,7 +16,11 @@ # under the License. # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition """Elementwise operators""" -from __future__ import absolute_import as _abs + +from typing import Optional + +from tvm import te + from . import cpp @@ -73,3 +77,32 @@ def full_like(x, fill_value): The result. """ return cpp.full_like(x, fill_value) + + +def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: str = "float32") -> te.Tensor: + """Generate an identity matrix or a matrix with ones on the k-th diagonal. + + Parameters + ---------- + n : int + Number of rows + m : int, optional + Number of columns. If None, defaults to n. + k : int, optional + Index of the diagonal. 0 (default) refers to the main diagonal. + A positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + dtype : str, optional + Data type of the returned array. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + m = m if m is not None else n + return te.compute( + (n, m), + lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype), te.const(0, dtype)), + name="eye", + ) diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 3798ba190446..4f4d7e18578f 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -215,7 +215,9 @@ PresburgerSet Intersect(const Array& sets) { IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); -#if TVM_MLIR_VERSION >= 160 +#if TVM_MLIR_VERSION >= 190 + SmallVector coeffs; +#elif TVM_MLIR_VERSION >= 160 SmallVector coeffs; #else SmallVector coeffs; @@ -223,7 +225,9 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { coeffs.reserve(tvm_coeffs.size()); for (const PrimExpr& it : tvm_coeffs) { -#if TVM_MLIR_VERSION >= 160 +#if TVM_MLIR_VERSION >= 190 + coeffs.push_back(llvm::DynamicAPInt(*as_const_int(it))); +#elif TVM_MLIR_VERSION >= 160 coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it))); #else coeffs.push_back(*as_const_int(it)); diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index ce025540e496..22b0933db4b4 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -236,7 +236,8 @@ class JSONTokenizer { str.push_back('\t'); break; default: - LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_; + LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_ + << ". record:" << std::string(cur_, end_); } } if (cur_ == end_) { diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 71ae43387112..87fa96f67ceb 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); } + bool Apply(const tir::Schedule& sch) final { + try { + return tir::RewriteLayout(sch); + } catch (const std::runtime_error& e) { + return false; + } + } Postproc Clone() const { ObjectPtr n = make_object(*this); diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 6ad71e0f85bf..1e7fa8172718 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -42,6 +42,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(multiply); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(power); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(subtract); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(mod); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index f1dc3d4904c8..bd4c681c7925 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -181,6 +181,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 003bcb7e27cf..b66eb96f8452 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -79,6 +79,12 @@ Expr power(Expr x1, Expr x2); /*! \brief Subtraction with numpy-style broadcasting. */ Expr subtract(Expr x1, Expr x2); +/*! \brief Modulo with numpy-style broadcasting. */ +Expr mod(Expr x1, Expr x2); + +/*! \brief Floor modulo with numpy-style broadcasting. */ +Expr floor_mod(Expr x1, Expr x2); + /***************** Comparison operators *****************/ /*! \brief Broadcasted element-wise test for (lhs == rhs). */ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 7aca1470aee4..8696d85f7756 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -228,6 +228,90 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", Bool(true)); +/* relax.eye & relax.eye_like */ +Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.eye"); + return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); +} + +Expr eye_like(Expr x, PrimValue k, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.eye_like"); + return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); +TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); + +StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye op should have 3 arguments: n, m, and k, but got " << call->args.size() + << " arguments"); + } + + auto get_prim_value = [&ctx](const Expr& expr, std::string key) { + if (!expr->IsInstance()) { + ctx->ReportFatal(Diagnostic::Error(expr) + << "Eye expects the `" << key << "` to be a PrimValue, but got " + << expr->GetTypeKey()); + } + return expr.as()->value; + }; + + PrimExpr n = get_prim_value(call->args[0], "n"); + PrimExpr m = get_prim_value(call->args[1], "m"); + + DataType dtype = call->attrs.as()->dtype; + return TensorStructInfo(ShapeExpr({n, m}), dtype); +} + +StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like op should have 2 arguments: x and k, but got " + << call->args.size() << " arguments"); + } + + const auto* x_sinfo = GetStructInfoAs(call->args[0]); + if (x_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like expects the input `x` to be a Tensor, but got " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like expects the input tensor to be 2-dimensional, but got " + << x_sinfo->ndim << " dimensions"); + } + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype; + + return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.eye") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("n", "PrimValue", "Number of rows in the output.") + .add_argument("m", "PrimValue", "Number of columns in the output.") + .add_argument("k", "PrimValue", "Index of the diagonal.") + .set_attr("FInferStructInfo", InferStructInfoEye) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + +TVM_REGISTER_OP("relax.eye_like") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("k", "PrimValue", "Index of the diagonal.") + .set_attr("FInferStructInfo", InferStructInfoEyeLike) + .set_attr("FPurity", Bool(true)); + /* relax.arange */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { ObjectPtr attrs = make_object(); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 6e7c8255238a..d88336146d44 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -72,12 +72,48 @@ Expr ones(Expr shape, DataType dtype); */ Expr ones_like(Expr x, DataType dtype); -/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */ +/*! + * \brief Construct a tensor of all zeros, with the input shape and dtype. + * \param shape The shape of the created tensor. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ Expr zeros(Expr shape, DataType dtype); -/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */ +/*! + * \brief Construct a tensor with all zeros, with shape of the input tensor shape. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ Expr zeros_like(Expr x, DataType dtype); +/*! + * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. + * \param n The number of rows and columns in the output. + * \param m The number of columns in the output. If None, defaults to n. + * \param k The index of the diagonal. A positive value refers to an upper diagonal, + * a negative value to a lower diagonal, and 0 to the main diagonal. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); + +/*! + * \brief Construct a tensor with ones on the diagonal and zeros elsewhere, + * with shape and dtype similar to the input tensor. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param k The index of the diagonal. A positive value refers to an upper diagonal, + * a negative value to a lower diagonal, and 0 to the main diagonal. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr eye_like(Expr x, PrimValue k, DataType dtype); + /*! \brief Construct a tensor with evenly spaced elements. */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2b1c6eafb652..ba443413025a 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -30,6 +30,8 @@ #include #include +#include "tvm/runtime/data_type.h" + namespace tvm { namespace relax { @@ -1531,5 +1533,212 @@ TVM_REGISTER_OP("relax.scatter_elements") .set_attr("FInferStructInfo", InferStructInfoScatterElements) .set_attr("FPurity", Bool(true)); +/* relax.scatter_nd */ +TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); + +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { + auto attrs = make_object(); + attrs->reduction = std::move(reduction); + static const Op& op = Op::Get("relax.scatter_nd"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); + +StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { + // `call->args` contains: [data, indices, updates] + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + ICHECK_EQ(call->args.size(), 3); + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* indices_sinfo = GetStructInfoAs(call->args[1]); + const auto* updates_sinfo = GetStructInfoAs(call->args[2]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input data to be a tensor. However, the given type is " + << call->args[0]->GetTypeKey()); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input indices to be a tensor. However, the given type is " + << call->args[1]->GetTypeKey()); + } + if (updates_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input updates to be a tensor. However, the given type is " + << call->args[2]->GetTypeKey()); + } + + if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data and updates to have known dtype. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (data_sinfo->dtype != updates_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data to have same type with updates. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (indices_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input indices to have integer dtype. However, " + "the given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + const auto* updates_shape = updates_sinfo->shape.as(); + + if (data_shape && indices_shape && updates_shape) { + const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as(); + if (!k_dim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND needs a static shape for the last axis of indices, got " + << indices_shape->values); + } + const size_t data_ndim = data_sinfo->ndim; + const size_t indices_ndim = indices_sinfo->ndim; + const size_t updates_ndim = updates_sinfo->ndim; + if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the rank of " + "`data tensor + indices tensor - last axis of indices tensor - 1`. " + "However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values) + << ", updates: " << ShapeExpr(updates_shape->values)); + } + if (k_dim->value > static_cast(data_ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the last axis of indices tensor to be less than " + "or equal to the rank of data tensor. However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values)); + } + Array expected_updates_shape; + for (size_t i = 0; i < indices_ndim - 1; i++) { + expected_updates_shape.push_back(indices_shape->values[i]); + } + for (size_t i = k_dim->value; i < data_ndim; i++) { + expected_updates_shape.push_back(data_shape->values[i]); + } + auto check_shape = [&](const Array& expected, const Array& actual) { + if (expected.size() != actual.size()) { + return false; + } + for (size_t i = 0; i < expected.size(); i++) { + if (!analyzer->CanProve(expected[i] == actual[i])) { + return false; + } + } + return true; + }; + if (!check_shape(expected_updates_shape, updates_shape->values)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the shape with constraint: " + << "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got " + << "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: " + << ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values)); + } + } + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.scatter_nd") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") + .set_attr("FInferStructInfo", InferStructInfoScatterND) + .set_attr("FPurity", Bool(true)); + +/* relax.one_hot */ +TVM_REGISTER_NODE_TYPE(OneHotAttrs); +Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) { + ObjectPtr attrs = make_object(); + attrs->depth = depth; + attrs->axis = axis; + + // Check if on_value and off_value have the same dtype + DataType on_dtype = on_value->value->dtype; + DataType off_dtype = off_value->value->dtype; + ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_dtype << " and " << off_dtype; + + ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth; + + static const Op& op = Op::Get("relax.one_hot"); + return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); +} // namespace relax + +TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); + +StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); + const auto* attrs = call->attrs.as(); + PrimValue on_value = Downcast(call->args[1]); + PrimValue off_value = Downcast(call->args[2]); + // Check if on_value and off_value have the same dtype + ICHECK(on_value->value->dtype == off_value->value->dtype) + << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; + DataType dtype = on_value->value->dtype; + + // Check if indices has an integer dtype + if (indices_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "one_hot op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + // Check if indices has unknown dimension + if (indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice); + } + // Get the shape of indices + const auto* indices_shape = indices_sinfo->shape.as(); + if (indices_shape == nullptr) { + return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice); + } + + Array output_shape = indices_shape->values; + int axis = attrs->axis; + if (axis < 0) { + axis += output_shape.size() + 1; + } + ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) + << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " + << "but got " << axis; + output_shape.insert(output_shape.begin() + axis, attrs->depth); + + return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.one_hot") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("on_value", "PrimValue", "The value to fill at specified indices.") + .add_argument("off_value", "PrimValue", "The value to fill at other indices.") + .set_attr("FInferStructInfo", InferStructInfoOneHot) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 68622f1359e0..010ceb663ef3 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -27,6 +27,7 @@ #include #include "../op_common.h" +#include "tvm/relax/expr.h" namespace tvm { namespace relax { @@ -173,6 +174,50 @@ Expr tile(Expr data, Array repeats); */ Expr flip(Expr data, Integer axis); +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param axis The axis along which to scatter the elements. + * \param reduction The reduction mode of the scatter elements, + * either "update", "add", "mul", "mean", "max" or "min". + * \return The computed result. + */ +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction); + +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor to be updated. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param reduction The reduction mode of the scatter operation. + * Supported modes are: + * - "update": Replace the values at the indices with the update values. + * - "add": Add the update values to the existing values at the indices. + * - "mul": Multiply the existing values at the indices by the update values. + * - "max": Take the maximum of the existing value and the update value at each index. + * - "min": Take the minimum of the existing value and the update value at each index. + * \return The computed result tensor with the same shape as `data`. + * + * \note The shape of `indices` defines the shape of the scattered tensor. + * The last dimension of `indices` corresponds to the depth of each index vector. + * The shape of `updates` must match the shape of `indices` except for the last dimension, + * which must match the slice shape at each index. + */ +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); + +/*! + * \brief Returns a one-hot tensor. + * \param indices The indices to set to `on_value`. + * \param on_value The value to fill at `indices`. + * \param off_value The value to fill at other locations. + * \param depth The depth of the one hot dimension. + * \param axis The axis to fill. + * \return The computed result. + */ +Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 29d9d52c6077..c659a49afd12 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -24,6 +24,7 @@ #include "set.h" +#include #include #include @@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique") .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); +/* relax.nonzero */ +Expr nonzero(Expr x) { + static const Op& op = Op::Get("relax.nonzero"); + return Call(op, {std::move(x)}); +} + +TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); + +StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // Cheat zero dim scalar as 1-dim. + int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1; + return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.nonzero") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoNonzero) + .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index a5c7ee85bfb2..251dd1975e9f 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -29,8 +29,36 @@ namespace tvm { namespace relax { +/*! + * \brief Find the unique elements in a given tensor. + * In addition, it optionally returns + * - the indices of the input tensor that give the unique values; + * - the indices of the unique tensor that reconstruct the input tensor; + * - the number of times each unique value comes up in the input tensor. + * \param x The input tensor. + * \param sorted Whether to sort the unique elements in ascending order before + * returning as output. + * \param return_index Whether to return an additional tensor with indices for where elements in + * the unique tensor come from the original input. + * \param return_inverse Whether to return an additional tensor with indices for where elements in + * the original input ended up in the returned unique list. + * \param return_counts Whether to return an additional tensor with counts of each unique elements. + * \param axis The dimension to apply unique. + * If not specified, the unique values of the flattened input are returned. + * \return The unique elements of the array. The returned array will be sorted if `sorted` is True. + * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. + */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, PrimValue return_counts, Optional axis); + +/*! + * \brief Returns the indices of the non-zero elements of the input tensor. + * \param x The input tensor. + * \return a list of 1-D tensors containing indices of non-zero elements for each dimension. + * \note This function behaves similarly to numpy.nonzero(), but return a multi-dimensional array + * instead of a tuple of 1-D arrays. + */ +Expr nonzero(Expr x); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc new file mode 100644 index 000000000000..64062e224372 --- /dev/null +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/attach_attr_layout_free_buffers.cc + * \brief Attach layout_free_buffers for layout-free buffers. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +class AttrAttacher : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + AttrAttacher mutator(mod); + for (auto [gvar, func] : mod->functions) { + if (func->IsInstance()) { + // clear the layout_free_exprs_ for each function + mutator.layout_free_exprs_.clear(); + mutator.builder_->UpdateFunction(gvar, Downcast(mutator.VisitExpr(func))); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* op) final { + if (auto opt_num_input = op->attrs.GetAttr(attr::kNumInput)) { + ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with num_input attr"; + size_t num_input = opt_num_input.value()->value; + for (size_t i = num_input; i < op->params.size(); i++) { + layout_free_exprs_.insert(op->params[i].get()); + } + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const ConstantNode* op) final { + layout_free_exprs_.insert(op); + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + Call call = Downcast(ExprMutator::VisitExpr_(op)); + if (call->op != call_tir_op_) { + return call; + } + GlobalVar gv = Downcast(call->args[0]); + Array call_tir_args = Downcast(call->args[1])->fields; + // Compute the layout free buffers + Array layout_free_buffers; + for (size_t i = 0; i < call_tir_args.size(); i++) { + if (layout_free_exprs_.count(call_tir_args[i].get())) { + layout_free_buffers.push_back(Integer(i)); + } + } + // Attach the layout free buffers to the tir::PrimFunc + tir::PrimFunc func = WithAttr(Downcast(mod_->Lookup(gv)), "layout_free_buffers", + layout_free_buffers); + // Renew defs + func = tir::RenewDefs(func); + // Add the updated tir::PrimFunc in the IRModule + // Note the blockbuilder would automatically combine the same tir function + // So we don't need to worry about the duplicate insertion + GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint); + // Create a new call node with the updated tir::PrimFunc + auto n = make_object(*op); + n->args = {new_gv, Tuple(call_tir_args)}; + return Call(n); + } + + private: + IRModule mod_; + std::unordered_set layout_free_exprs_; +}; +namespace transform { + +Pass AttachAttrLayoutFreeBuffers() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); }; + auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {}); + // Apply DeadCodeElimination to remove unused tir::PrimFunc + return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); +} + +TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") + .set_body_typed(AttachAttrLayoutFreeBuffers); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc new file mode 100644 index 000000000000..5fee946c26dd --- /dev/null +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/split_tir_layout_rewrite.cc + * \brief Use for rewriting the TIRs after meta_schedule layout rewrite post process. + */ +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { +class SplitPrimFuncLayoutRewrite : public StmtMutator { + public: + explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} + std::tuple, PrimFunc> Transform(const PrimFunc& func) { + ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; + const auto& block = func->body.as()->block; + visit_root_block(block.get()); + if (layout_rewrite_preproc_stmts_.size() > 0) { + return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func()); + } else { + return std::make_tuple(NullOpt, func); + } + } + + private: + void sort_rewrite_infos() { + std::sort( + rewrite_infos_.begin(), rewrite_infos_.end(), + [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index < b.buffer_index; }); + } + + PrimFunc create_layout_rewrite_preproc_func() const { + // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers + ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; + + // Step 2: Create the params for the new PrimFunc + Array params; + Map buffer_map; + + for (const auto& info : rewrite_infos_) { + params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); + buffer_map.Set(params.back(), info.pre_rewrite_buffer); + } + for (const auto& info : rewrite_infos_) { + params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle())); + buffer_map.Set(params.back(), info.post_rewrite_buffer); + } + + // Step 3: Create the body for the new PrimFunc + ICHECK(layout_rewrite_preproc_stmts_.size() > 0) + << "There should be at least one layout rewrite preproc stmt."; + Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] + : SeqStmt(layout_rewrite_preproc_stmts_); + body = BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body)); + + PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map); + + return RenewDefs(func); + } + + PrimFunc create_compute_func() const { + // Step 1: Create the params for the new PrimFunc + Array params = original_func_->params; + Map buffer_map = original_func_->buffer_map; + for (const auto& info : rewrite_infos_) { + const Var& param = params[info.buffer_index]; + ICHECK(buffer_map[param] == info.pre_rewrite_buffer); + buffer_map.Set(param, info.post_rewrite_buffer); + } + + // Step 2: Create the body for the new PrimFunc + Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_); + Block original_block = original_func_->body.as()->block; + Array alloc_buffers; + for (const auto& buffer : original_block->alloc_buffers) { + auto it = + std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(), + [&](const RewriteInfo& info) { return info.post_rewrite_buffer == buffer; }); + if (it == rewrite_infos_.end()) { + alloc_buffers.push_back(buffer); + } + } + + body = BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body, + /*init=*/NullOpt, + /*alloc_buffers=*/alloc_buffers)); + + PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map); + return RenewDefs(func); + } + + void visit_root_block(const BlockNode* op) { + Stmt body = op->body; + if (const auto* seq_stmt = body.as()) { + for (const auto& stmt : seq_stmt->seq) { + current_subtree_ = 0; + Stmt new_stmt = this->VisitStmt(stmt); + ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree."; + if (current_subtree_ == 1) { + layout_rewrite_preproc_stmts_.push_back(new_stmt); + } else { + compute_stmts_.push_back(new_stmt); + } + } + } else { + current_subtree_ = 0; + this->VisitStmt(body); + ICHECK(current_subtree_ == -1) + << "There should be a compute block if there is only one subtree under the root."; + } + } + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtMutator::VisitStmt_(op)); + auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc); + bool is_layout_rewrite_preproc = + it != op->annotations.end() && is_one(Downcast((*it).second)); + + if (current_subtree_ == 0) { + current_subtree_ = is_layout_rewrite_preproc ? 1 : -1; + } else if (current_subtree_ == 1) { + CHECK(is_layout_rewrite_preproc) + << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block."; + } else { + CHECK(!is_layout_rewrite_preproc) + << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block."; + } + + if (is_layout_rewrite_preproc) { + ICHECK(op->reads.size() == 1) << "There should be only one read buffer in the layout rewrite"; + ICHECK(op->writes.size() == 1) + << "There should be only one write buffer in the layout rewrite"; + ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in the layout rewrite"; + ICHECK(op->match_buffers.empty()) << "There should be no match buffer in the layout rewrite"; + const Buffer& preproc_buffer = op->reads[0]->buffer; + int buffer_index = -1; + for (size_t i = 0; i < original_func_->params.size(); ++i) { + const Buffer& buffer = original_func_->buffer_map[original_func_->params[i]]; + if (buffer == preproc_buffer) { + buffer_index = i; + break; + } + } + ICHECK(buffer_index != -1) << "The preproc buffer is not found in the original primfunc."; + rewrite_infos_.push_back( + RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer}); + + auto new_annotations = op->annotations; + new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc); + auto n = make_object(*block.get()); + n->annotations = new_annotations; + return Block(n); + } + return block; + } + + public: + struct RewriteInfo { + int buffer_index; + Buffer pre_rewrite_buffer; + Buffer post_rewrite_buffer; + }; + std::vector rewrite_infos_; + + private: + /*! \brief The stmts that are used for layout rewrite preproc*/ + Array layout_rewrite_preproc_stmts_; + /*! \brief The stmts that are other than layout rewrite preproc*/ + Array compute_stmts_; + /*! + \brief Whether the current subtree is a layout rewrite preproc subtree. + -1: visited a non-layout rewrite preproc block + 0: unsure, not visited any block + 1: visited a layout rewrite preproc block + */ + int current_subtree_; + /*! \brief The original primfunc*/ + PrimFunc original_func_; +}; +} // namespace tir + +namespace relax { +class SplitLayoutRewritePreproc : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + SplitLayoutRewritePreproc mutator(mod); + + // Step 1: Split the primfunc into preproc and compute + for (auto [gv, func] : mod->functions) { + if (func->IsInstance()) { + tir::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast(func)); + auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast(func)); + if (preproc_func.defined()) { + mutator.split_funcs_.emplace(gv.get(), + std::make_tuple(preproc_func.value(), compute_func)); + mutator.rewrite_infos_.emplace(gv.get(), tir_rewriter.rewrite_infos_); + } + } + } + + for (auto [gv, func] : mod->functions) { + if (func->IsInstance()) { + auto relax_func = Downcast(func); + mutator.builder_->UpdateFunction(gv, Downcast(mutator(relax_func))); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {} + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + Call call = Downcast(ExprMutator::VisitExpr_(op)); + + // Step 1: Skip call to other than `tir.call_tir` + if (!call->op.same_as(call_tir_op)) { + return call; + } + + // Step 2: Skip if there is no preproc stage + const GlobalVar gv = Downcast(call->args[0]); + auto it = split_funcs_.find(gv.get()); + if (it == split_funcs_.end()) { + return call; + } + + // Step 3: Get the preproc and compute functions and update the module + const auto& [preproc_func, compute_func] = it->second; + GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint + "_weight_prepack"); + GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked"); + // Step 4. Get rewrite infos + auto rewrite_infos_it = rewrite_infos_.find(gv.get()); + ICHECK(rewrite_infos_it != rewrite_infos_.end()) + << "Rewrite infos are not found for " << gv->name_hint; + const auto& rewrite_infos = rewrite_infos_it->second; + + // Step 5: Emit the preproc call + Array call_tir_args = Downcast(call->args[1])->fields; + Array preproc_args; + Array preproc_sinfo_list; + for (const auto& info : rewrite_infos) { + preproc_args.push_back(call_tir_args[info.buffer_index]); + tir::Buffer rewritten_buffer = info.post_rewrite_buffer; + for (const auto& shape_expr : rewritten_buffer->shape) { + CHECK(shape_expr.as()) << "Currently does not support rewrite buffer with " + "dynamic shape."; + } + preproc_sinfo_list.push_back( + TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); + } + StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1 // + ? TupleStructInfo(preproc_sinfo_list) // + : preproc_sinfo_list[0]; + + // Step 6: Call the preproc function + Expr preproc_call = + builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo})); + if (rewrite_infos.size() == 1) { + call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call); + } else { + for (size_t i = 0; i < rewrite_infos.size(); ++i) { + call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i)); + } + } + Expr main_call = + builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args)); + + return main_call; + } + + private: + std::unordered_map> split_funcs_; + std::unordered_map> + rewrite_infos_; +}; + +} // namespace relax + +namespace transform { +Pass SplitLayoutRewritePreproc() { + auto pass_func = [](IRModule mod, PassContext pc) { + return relax::SplitLayoutRewritePreproc::Transform(mod); + }; + auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {}); + return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, + "SplitLayoutRewritePreproc"); +} +TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") + .set_body_typed(SplitLayoutRewritePreproc); +} // namespace transform +} // namespace tvm diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 8c1607c4e56f..f752a487ea7e 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -171,6 +171,8 @@ inline const char* CLGetErrorString(cl_int error) { return "CL_INVALID_BUFFER_SIZE"; case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; + case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST: + return "CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST"; default: return "Unknown OpenCL error code"; } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b60e60c3cfc9..6195313fddae 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1581,14 +1581,14 @@ std::pair GetCumulativeSpaceAndReductionLength(const tir::Sche tir::IterVarType type = GetLoopIterType(loop_sref); if (type == tir::kDataPar) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_space_len *= *extent; } else { return std::make_pair(-1, -1); } } else if (type == tir::kCommReduce) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_reduce_len *= *extent; } else { return std::make_pair(-1, -1); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 73b5ff3fafd4..dd1a376deaf8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -246,8 +246,10 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int int max_innermost_factor, Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); + // use None RV object to denotes auto-infer tile factors. return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, - max_innermost_factor, &decision)); + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); throw; } @@ -1059,5 +1061,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const this->state_->DebugVerify(); } +void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, + index_map); + TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_); + this->state_->DebugVerify(); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 092bcf0c79f9..4aebe3036cf2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode { void EnterPostproc() override {} void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, const Array& buf_index_array) override; + void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) override; protected: /******** Utility functions ********/ @@ -217,9 +219,12 @@ class ConcreteScheduleNode : public ScheduleNode { /*! * \brief Add a list of integers as random variables into the symbol table * \param value The list of integers to be added to the symbol table + * \param convert_negone_to_none Convert negative one to none RV. + * Which is convention of certain primitives. * \return The new random variables created */ - inline Array CreateRV(const std::vector& value); + inline Array CreateRV(const std::vector& value, + bool convert_negone_to_none = false); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! @@ -360,10 +365,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, + bool convert_negone_to_none) { Array results; results.reserve(value.size()); for (int64_t v : value) { + if (convert_negone_to_none && v == -1) { + results.push_back(ExprRV(nullptr)); + continue; + } results.push_back(CreateRV(v)); } return results; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fd1349e4a3ec..cf1ac957c89f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, const String& buf_type, const Array& buf_index_array); +/*! + * \brief Annotate the read or write region of a specific buffer in a block + * \param self The state of the schedule + * \param block_sref The sref of the block to be annotated + * \param buffer_index The index of the buffer in block's read or write region + * \param buffer_index_type The type of the buffer index, kRead or kWrite + * \param index_map The IndexMap that defines the new read or write region for the buffer + */ +TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc new file mode 100644 index 000000000000..2c5976b035dd --- /dev/null +++ b/src/tir/schedule/primitive/annotate_buffer_access.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class AnnotateRegionRewriter : public StmtExprMutator { + public: + AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region, + BufferIndexType buffer_index_type) + : buffer_(buffer), + buffer_index_(buffer_index), + new_region_(new_region), + buffer_index_type_(buffer_index_type) {} + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + Array regions = + buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads; + ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; + ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; + regions.Set(buffer_index_, new_region_); + + ObjectPtr n = CopyOnWrite(block.get()); + if (buffer_index_type_ == BufferIndexType::kWrite) { + n->writes = std::move(regions); + } else { + n->reads = std::move(regions); + } + + // Annotate the block with explicit_read_region or explicit_write_region + Map new_annotations = n->annotations; + String annotation_key = buffer_index_type_ == BufferIndexType::kWrite + ? attr::explicit_write_region + : attr::explicit_read_region; + if (new_annotations.count(annotation_key)) { + Array buffer_indices = Downcast>(new_annotations[annotation_key]); + bool found = false; + for (const Integer& index : buffer_indices) { + if (index->value == buffer_index_) { + found = true; + break; + } + } + if (!found) { + buffer_indices.push_back(Integer(buffer_index_)); + new_annotations.Set(annotation_key, buffer_indices); + } + } else { + new_annotations.Set(annotation_key, Array{Integer(buffer_index_)}); + } + n->annotations = std::move(new_annotations); + + return Block(n); + } + + private: + Buffer buffer_; + int buffer_index_; + BufferRegion new_region_; + BufferIndexType buffer_index_type_; +}; + +void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, buffer_index_type); + + arith::Analyzer analyzer; + Array block_iter_vars; + for (const IterVar& iter_var : block->iter_vars) { + block_iter_vars.push_back(iter_var->var); + } + Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; + Array new_ranges; + for (size_t i = 0; i < new_indices.size(); i += 2) { + // (begin, end) represents a region + new_ranges.push_back(Range::FromMinExtent( + new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i]))); + } + + BufferRegion new_region(buffer, new_ranges); + + AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type); + Stmt new_stmt = mutator(GetRef(block_sref->stmt)); + + self->Replace(block_sref, new_stmt, {{GetRef(block), Downcast(new_stmt)}}); +} + +struct AnnotateBufferAccessTraits : public UnpackedInstTraits { + static constexpr const char* kName = "AnnotateBufferAccess"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 4; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, + Integer buffer_index_type, IndexMap index_map) { + return sch->AnnotateBufferAccess(block, buffer_index->value, + static_cast(buffer_index_type->value), + index_map); + } + + static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) { + std::ostringstream oss; + oss << "lambda "; + for (size_t i = 0; i < index_map->initial_indices.size(); ++i) { + if (i != 0) oss << ", "; + oss << index_map->initial_indices[i]; + } + oss << ": ["; + for (size_t i = 0; i < index_map->final_indices.size(); i += 2) { + if (i != 0) oss << ", "; + if (index_map->final_indices[i].same_as(index_map->final_indices[i + 1])) { + oss << index_map->final_indices[i]; + } else { + oss << "(" << index_map->final_indices[i] << ", " << index_map->final_indices[i + 1] << ")"; + } + } + oss << "]"; + return String(oss.str()); + } + + static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, + Integer buffer_index_type, IndexMap index_map) { + PythonAPICall py("annotate_buffer_access"); + py.Input("block", block); + py.Input("buffer_index", buffer_index->value); + + std::ostringstream os; + os << "\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\""; + py.Input("buf_type", os.str()); + + py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map)); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 44f9b8f42c68..2c3661d17ecc 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") .set_body_method(&ScheduleNode::UnsafeHideBufferAccess); +/******** (FFI) Annotate buffer access ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") + .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, + int buffer_index_type, const IndexMap& index_map) { + return self->AnnotateBufferAccess(block_rv, buffer_index, + static_cast(buffer_index_type), index_map); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 6e243bf19198..7421cbbf32df 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -227,7 +227,9 @@ Array TranslateAddOutputRVs( ICHECK(!rv_names->count(output)) << "ValueError: The random variable has been produced once: " << rv_names->at(output); String result{ObjectPtr{nullptr}}; - if (output->IsInstance()) { + if (!output.defined()) { + result = "_"; + } else if (output->IsInstance()) { result = "b" + std::to_string(i); } else if (output->IsInstance()) { result = "l" + std::to_string(i); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 1611109d7735..784ecdeb32cb 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -70,9 +70,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidat Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision) { - Array results = CreateRV(tir::SamplePerfectTile( - &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); - + // use None RV object to denotes auto-infer tile factors. + Array results = + CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{loop_rv}, @@ -769,5 +771,17 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const S /*outputs=*/{})); } +void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) { + ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map); + static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map}, + /*attrs=*/{}, + /*outputs=*/{})); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 78629e84f039..1c21c3e2c894 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -142,6 +142,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { void EnterPostproc() final; void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, const Array& buf_index_array) final; + void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; } // namespace tir diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index f562a057e595..7385af49528b 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -136,7 +136,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + auto explicit_it = explicit_access_annotations_.find(op->buffer); + if (explicit_it != explicit_access_annotations_.end()) { + VisitBufferAccess(explicit_it->second); + } else { + VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + } StmtExprVisitor::VisitExpr_(op); } @@ -235,17 +240,38 @@ class BufferAccessRegionCollector : public StmtExprVisitor { auto& regions = access_annotations_[p.first]; p.second.swap(regions); } - // Step 2. Record relax position of ancestor_loops_ + + // Step 2. Record explicit read/write region annotations + auto record_explicit_region = [&](const String& attr_key, BufferIndexType index_type) { + auto it = op->annotations.find(attr_key); + if (it != op->annotations.end()) { + Array buffer_indices = Downcast>((*it).second); + for (const auto& index : buffer_indices) { + int buffer_index = index->value; + if (buffer_index >= 0 && buffer_index < static_cast(op->reads.size())) { + const BufferRegion& explicit_region = index_type == BufferIndexType::kRead + ? op->reads[buffer_index] + : op->writes[buffer_index]; + explicit_access_annotations_[explicit_region->buffer] = explicit_region; + } + } + } + }; + + record_explicit_region(attr::explicit_read_region, BufferIndexType::kRead); + record_explicit_region(attr::explicit_write_region, BufferIndexType::kWrite); + + // Step 3. Record relax position of ancestor_loops_ for (const Buffer& buffer : op->alloc_buffers) { VisitBufferDef(buffer->data); } - // Step 3. Visit match buffers + // Step 4. Visit match buffers for (const MatchBufferRegion& region : op->match_buffers) { VisitBufferAccess(region->source); } - // Step 4. Visit block body recursively + // Step 5. Visit block body recursively StmtExprVisitor::VisitStmt_(op); - // Step 5. Recover read/write region annotations + // Step 6. Recover read/write region annotations for (auto& p : cur_access_annotations) { auto& regions = access_annotations_[p.first]; if (p.second.empty()) { @@ -254,7 +280,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { regions.swap(p.second); } } - // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. + // Step 7. Clear explicit access annotations + explicit_access_annotations_.clear(); + // Step 8. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. for (const Buffer& buffer : op->alloc_buffers) { ICHECK_EQ(var2buffer_[buffer->data].size(), 1) << "Block allocation buffer shoud not be alised"; @@ -489,6 +517,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief The map from Buffer to it's access regions annotated by current block. */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> access_annotations_; + /*! \brief The map from Buffer to its explicit access region annotated by the block. */ + std::unordered_map + explicit_access_annotations_; }; /*! \brief The storage alignment for a dimension */ diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index d98a8d588e22..64218f02a0ab 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -86,6 +86,19 @@ def test_default_arg_for_mrvl_hybrid(): assert parsed.target_mrvl_num_tiles == 8 +@tvm.testing.requires_mrvl +# Test for default(LLVM) target, when built with USE_MRVL=ON +def test_mrvl_build_with_llvm_only_target(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=llvm", + ] + ) + assert parsed.target == "llvm" + + @tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser() diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py index e2305de2afaf..8348c57c1949 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py @@ -61,7 +61,8 @@ def inner(mod): ) sch = tvm.tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.space_generator.postprocs[0].apply(sch) + if not ctx.space_generator.postprocs[0].apply(sch): + raise tvm.TVMError("RewriteLayout postproc failed") return sch.mod return inner diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index e3ed3a3a9d4d..46373510b101 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -63,8 +63,11 @@ def generate_random_inputs( if dtype == "bool": # random_value = np.random.choice(a=[False, True], size=shape) random_value = rg.choice(a=[False, True], size=shape) + elif dtype.startswith("int"): + # Keep non-zero values + random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype) + random_value[random_value <= 0] -= 1 else: - # random_value = np.random.normal(size=shape).astype(dtype) random_value = rg.standard_normal(size=shape).astype(dtype) input_values[i.name] = random_value @@ -118,7 +121,6 @@ def check_correctness( tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) - print(tvm_model) # Separate model from parameters. tvm_model, params = relax.frontend.detach_params(tvm_model) @@ -247,7 +249,6 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32 ) model = helper.make_model(graph, producer_name="binary_test") - # NOTE: explicitly pass inputs to avoid numerical error check_correctness(model, opset=opset) @@ -328,6 +329,16 @@ def test_binary(op_name: str): verify_binary_scalar(op_name) +@pytest.mark.parametrize("int_mode", [True, False]) +def test_mod(int_mode: bool): + if int_mode: + dtype, fmod = TensorProto.INT32, 0 + else: + dtype, fmod = TensorProto.FLOAT, 1 + verify_binary("Mod", [1, 32], [1, 32], [1, 32], attrs={"fmod": fmod}, dtype=dtype) + verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype) + + @pytest.mark.parametrize("num_inputs", [1, 2, 4]) @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"]) def test_multi_input(op_name: str, num_inputs: int): @@ -431,6 +442,7 @@ def test_bitwise_shift(direction: str): "Sigmoid", "Softmax", "LogSoftmax", + "Hardmax", "Identity", ], ) @@ -446,7 +458,7 @@ def test_unary(op_name: str): output_dtype = TensorProto.BOOL else: output_dtype = TensorProto.FLOAT - verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype) + verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype) @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -523,6 +535,38 @@ def test_scatter(axis: int, name: str, opset: int): check_correctness(model, inputs={"indices": indices}, opset=opset) +@pytest.mark.parametrize("reduction", ["none", "add", "mul"]) +def test_scatter_nd(reduction): + def verify_scatter_nd(data_shape, indices_shape, updates_shape): + scatter_nd_node = helper.make_node( + "ScatterND", + ["data", "indices", "updates"], + ["output"], + reduction=reduction, + ) + + graph = helper.make_graph( + [scatter_nd_node], + "scatter_nd_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)], + ) + + model = helper.make_model(graph, producer_name="scatter_nd_test") + + indices = np.random.choice(data_shape[0], indices_shape) + check_correctness(model, inputs={"indices": indices}, opset=16) + + verify_scatter_nd([8], [4, 1], [4]) + verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4]) + verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6]) + verify_scatter_nd([10], [5, 1], [5]) + + def test_size(): test_node = helper.make_node("Size", ["x"], ["y"]) graph = helper.make_graph( @@ -536,6 +580,11 @@ def test_size(): check_correctness(model) +@pytest.mark.parametrize("k", [-1, 0, 1]) +def test_eye_like(k: int): + verify_unary("EyeLike", [32, 32], attrs={"k": k}) + + @pytest.mark.parametrize("alpha", [None, 0.25, 1.0]) @pytest.mark.parametrize("beta", [None, 0.35, 1.0]) @pytest.mark.parametrize("useC", [False, True]) @@ -935,7 +984,7 @@ def test_cumsum1(): ) model = helper.make_model(graph, producer_name="cumsum_graph") - check_correctness(model) + check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)}) @pytest.mark.parametrize("axis", [[0, 2], None]) @@ -1665,6 +1714,63 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0): verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") +@pytest.mark.parametrize("dynamic", [True, False]) +def test_pad_v2(dynamic): + + if dynamic: + pytest.skip("Dynamic pad not supported") + + def verify_pad(input_shape, pads, mode="constant", value=0.0): + indata = np.random.normal(size=input_shape).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ["edge", "reflect"]: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + "Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", + inputs=["input"], + outputs=["output"], + mode="constant", + pads=pads, + value=value, + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") + check_correctness(model=model, opset=10) + + verify_pad((2, 2), [0, 1, 0, 0], "constant", 0.0) + verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0) + verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0) + verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") + + @pytest.mark.parametrize("fp_arith", [np.float16, np.float32]) @pytest.mark.parametrize("dynamic", [True, False]) def test_split(fp_arith, dynamic): @@ -2162,6 +2268,11 @@ def test_unique(axis: Optional[int], sorted: int): check_correctness(model) +@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)]) +def test_nonzero(shape): + verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64) + + @pytest.mark.parametrize("mode", ["DCR", "CRD"]) def test_depth_to_space(mode: Literal["DCR", "CRD"]): in_shape = [1, 8, 2, 3] diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index 1e895169f620..67f347019163 100644 --- a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -545,6 +545,64 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.zeros_like(x1)) +def test_eye_infer_struct_info(): + bb = relax.BlockBuilder() + + _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32")) + _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32")) + _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64")) + _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32")) + + +def test_eye_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + k = tir.Var("k", "int64") + + _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32")) + _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32")) + _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32")) + + +def test_eye_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((2, 5), "int64")) + x2 = relax.Var("x", R.Tensor((3, 3))) + + _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64")) + _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype="")) + _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32")) + _check_inference( + bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32") + ) + + +def test_eye_like_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((n, m), "float32")) + k = tir.Var("k", "int64") + + _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32")) + _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32")) + + +def test_eye_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.eye_like(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.eye_like(x1)) + + def test_arange_infer_struct_info(): bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index ddb92725d438..f6aefc859114 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -45,6 +45,7 @@ def test_op_correctness(): assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum") assert relax.op.flip(x, axis=1).op == Op.get("relax.flip") assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements") + assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -3352,5 +3353,81 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): bb.normalize(relax.op.scatter_elements(d0, i0, u4)) +def test_scatter_nd_infer_struct_info(): + bb = relax.BlockBuilder() + + d0 = relax.Var("data", R.Tensor((8,), "float32")) + i0 = relax.Var("indices", R.Tensor((4, 1), "int64")) + u0 = relax.Var("updates", R.Tensor((4,), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d0, i0, u0, "update"), + relax.TensorStructInfo((8,), dtype="float32"), + ) + + d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32")) + i1 = relax.Var("indices", R.Tensor((2, 1), "int64")) + u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d1, i1, u1, "update"), + relax.TensorStructInfo((4, 4, 4), dtype="float32"), + ) + + +def test_one_hot_infer_struct_info(): + bb = relax.BlockBuilder() + + # Test case 1: Basic usage + i0 = relax.Var("indices", R.Tensor((3,), "int32")) + _check_inference( + bb, + relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5), + relax.TensorStructInfo((3, 5), "float32"), + ) + + # Test case 2: With specified axis + i1 = relax.Var("indices", R.Tensor((2, 2), "int32")) + _check_inference( + bb, + relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1), + relax.TensorStructInfo((2, 3, 2), "int64"), + ) + + # Test case 3: With symbolic shape + n = tir.Var("n", "int64") + i2 = relax.Var("indices", R.Tensor((n,), "int32")) + _check_inference( + bb, + relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4), + relax.TensorStructInfo((n, 4), "float32"), + ) + + # Test case 4: With unknown shape + i3 = relax.Var("indices", R.Tensor("int32")) + _check_inference( + bb, + relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6), + relax.TensorStructInfo(dtype="float32"), + ) + + # Test case 5: With different on_value and off_value dtypes + i3 = relax.Var("indices", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5)) + + # Test case 6: With invalid indices dtype + i4 = relax.Var("indices", R.Tensor((2, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5)) + + # Test case 7: With invalid depth + i5 = relax.Var("indices", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py index 741d7869d52f..e9070f99fc3f 100644 --- a/tests/python/relax/test_op_set.py +++ b/tests/python/relax/test_op_set.py @@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype(): bb.normalize(relax.op.unique(x1)) +@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)]) +def test_nonzero_infer_struct_info(shape): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor(shape, "bool")) + + _check_inference( + bb, + relax.op.nonzero(x0), + relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_ndim_zero(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((), "bool")) + + _check_inference( + bb, + relax.op.nonzero(x), + relax.TensorStructInfo(ndim=2, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x1)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py new file mode 100644 index 000000000000..46f7c8aa87be --- /dev/null +++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm.testing + +from tvm import relax, tir +from tvm.script import relax as R, tir as T, ir as I +from tvm.relax.transform import CombineParallelMatmul +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def test_param(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_const(): + const_value = np.ones((32, 32), dtype="float32") + + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.matmul, + (x, relax.const(const_value)), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv = R.call_tir( + cls.matmul1, + (x, relax.const(const_value)), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_multiple_same_func(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul, + (lv1, w2), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul1, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul1, + (lv1, w2), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_multiple_same_func_with_different_free_buffers(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul, + (w2, lv1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @T.prim_func(private=True) + def matmul2( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [0]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul1, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul2, + (w2, lv1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index a0ecd3c73dc9..0565b7a5790a 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import pytest import tvm from tvm import relax from tvm.relax.transform import LegalizeOps @@ -1739,5 +1738,66 @@ def te_layout_transform( tvm.ir.assert_structural_equal(Expected, After) +def test_scatter_nd(): + + # fmt: off + @I.ir_module + class Before: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv: R.Tensor((8,), "float32") = R.scatter_nd(data, indices, updates, reduction="update") + return gv + + After = relax.transform.LegalizeOps()(Before) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv = R.call_tir( + Expected.scatter_nd, (data, indices, updates), R.Tensor((8,), dtype="float32") + ) + return gv + + @T.prim_func(private=True) + def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, var_scatter_nd_generic: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + data = T.match_buffer(var_data, (T.int64(8),), offset_factor=1) + indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64") + updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1) + out_buf = T.match_buffer(var_scatter_nd_generic, (T.int64(8),)) + with T.block("root"): + T.reads() + T.writes() + T_transpose = T.alloc_buffer((T.int64(1), T.int64(4)), "int64") + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(4)): + with T.block("T_transpose"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(4), ax1) + T.reads(indices[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = indices[v_ax1, v_ax0] + with T.block("scatter_nd_generic"): + T.reads() + T.writes() + for i in range(T.int64(8)): + out_buf[i] = data[i] + for j in range(T.int64(4)): + for k in T.parallel(T.int64(1)): + out_buf[k + T_transpose[j // T.int64(4), j % T.int64(4)]] = updates[j + k] + + # fmt: on + tvm.ir.assert_structural_equal(After, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py new file mode 100644 index 000000000000..e6b4c8ec4e2a --- /dev/null +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_single_buffer(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def tir_func( + X: T.Buffer((224, 224), "float32"), + W: T.Buffer((224, 224), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + W_rewrite = T.alloc_buffer((4, 4, 56, 56)) + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def tir_func_prepacked( + X: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @T.prim_func(private=True) + def tir_func_weight_prepack( + W: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + ): + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = After + with R.dataflow(): + lv = R.call_tir( + cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32") + ) + lv1 = R.call_tir( + cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32") + ) + gv: R.Tensor((224, 224), dtype="float32") = lv1 + R.output(gv) + return gv + + mod = relax.transform.SplitLayoutRewritePreproc()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +def test_multiple_buffers(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def tir_func( + X: T.Buffer((224, 224), "float32"), + W1: T.Buffer((224, 224), "float32"), + W2: T.Buffer((224, 224), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + W1_rewrite = T.alloc_buffer((4, 4, 56, 56)) + W2_rewrite = T.alloc_buffer((4, 4, 56, 56)) + for i, j in T.grid(224, 224): + with T.block("W1_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] + for i, j in T.grid(224, 224): + with T.block("W2_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = ( + X[vi, vj] + + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + ) + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w1: R.Tensor((224, 224), dtype="float32"), + w2: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def tir_func_prepacked( + X: T.Buffer((224, 224), "float32"), + W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = ( + X[vi, vj] + + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + ) + + @T.prim_func(private=True) + def tir_func_weight_prepack( + W1: T.Buffer((224, 224), "float32"), + W2: T.Buffer((224, 224), "float32"), + W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + ): + for i, j in T.grid(224, 224): + with T.block("W1_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] + for i, j in T.grid(224, 224): + with T.block("W2_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w1: R.Tensor((224, 224), dtype="float32"), + w2: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = After + with R.dataflow(): + lv0 = R.call_tir( + cls.tir_func_weight_prepack, + (w1, w2), + out_sinfo=[ + R.Tensor((4, 4, 56, 56), "float32"), + R.Tensor((4, 4, 56, 56), "float32"), + ], + ) + lv1 = R.call_tir( + cls.tir_func_prepacked, + (x, lv0[0], lv0[1]), + out_sinfo=R.Tensor((224, 224), "float32"), + ) + gv: R.Tensor((224, 224), dtype="float32") = lv1 + R.output(gv) + return gv + + mod = relax.transform.SplitLayoutRewritePreproc()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py new file mode 100644 index 000000000000..cc09a807dcac --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) + + +def test_annotate_read_buffer_access(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1 : vi - 1 + 2, vj - 1 : vj - 1 + 2]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)) + ) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_write_buffer_access(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi : vi + 2, vj : vj + 2]) + T.block_attr({"explicit_write_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2))) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_buffer_access_for_resize(): + # fmt: off + @T.prim_func + def resize_before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): + for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, 0:32, 0:32]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)])) + + @T.prim_func + def resize_expected(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): + for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"explicit_read_region": [0]}) + resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)])) + # fmt: on + sch = tir.Schedule(resize_before, debug_mask="all") + block = sch.get_block("resize") + sch.annotate_buffer_access( + block, + 0, + "read", + gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [ + v_i0, + v_i1, + (v_i2 * 2 - 3, v_i2 * 2 + 3), + (v_i3 * 2 - 3, v_i3 * 2 + 3), + ], + ) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], resize_expected) + verify_trace_roundtrip(sch=sch, mod=resize_before) + + +def test_annotate_buffer_access_read_and_write(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1 : vi + 2, vj - 1 : vj + 2]) + T.writes(B[vi : vi + 2, vj : vj + 2]) + T.block_attr({"explicit_read_region": [0], "explicit_write_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) + ) + + sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2))) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_double_annotate_buffer_access_read(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 2 : vi + 3, vj - 2 : vj + 3]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) + ) + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 2, vi + 3), (vj - 2, vj + 3)) + ) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_buffer_access_with_compute_at_for_resize(): + # fmt: off + @T.prim_func + def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200): + with T.block("cache"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 3, 100, 100): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))] + + @T.prim_func + def after(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer((1, 3, 200, 200)) + for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): + for ax0, ax1 in T.grid(24, 24): + with T.block("cache"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(3, i1) + v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0) + v3 = T.axis.spatial(200, i3_0 * 20 - 3 + ax1) + T.where(3 <= i2_0 * 20 + ax0 and i2_0 * 20 + ax0 < 203 and 3 <= i3_0 * 20 + ax1 and i3_0 * 20 + ax1 < 203) + T.reads(x[v0, v1, v2, v3]) + T.writes(x_global[v0, v1, v2, v3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i2_1, i3_1 in T.grid(10, 10): + with T.block("resize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) + v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) + T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6]) + T.writes(y[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"explicit_read_region": [0]}) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))] + + @T.prim_func + def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer((1, 3, 200, 200)) + for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): + for ax0, ax1 in T.grid(200, 200): + with T.block("cache"): + v0 = T.axis.spatial(1, 0) + v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1]) + T.reads(x[v0, v1, v2, v3]) + T.writes(x_global[v0, v1, v2, v3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i2_1, i3_1 in T.grid(10, 10): + with T.block("resize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) + v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) + T.reads(x_global[v_i0, v_i1, 0:200, 0:200]) + T.writes(y[v_i0, v_i1, v_i2, v_i3]) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))] + # fmt: on + + # Schedule with annotate_buffer_access + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("resize") + cache_block = sch.get_block("cache") + + # Annotate buffer access + sch.annotate_buffer_access( + block, + 0, + "read", + lambda vn, vc, vh, vw: (vn, vc, (vh * 2 - 3, vh * 2 + 3), (vw * 2 - 3, vw * 2 + 3)), + ) + + h, w = sch.get_loops(block)[-2:] + ho, hi = sch.split(h, factors=[10, 10]) + wo, wi = sch.split(w, factors=[10, 10]) + sch.reorder(ho, wo, hi, wi) + sch.compute_at(cache_block, wo) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], after) + verify_trace_roundtrip(sch=sch, mod=before) + + # Schedule without annotate_buffer_access + sch_without_annotate = tir.Schedule(before, debug_mask="all") + block_without_annotate = sch_without_annotate.get_block("resize") + cache_block_without_annotate = sch_without_annotate.get_block("cache") + + h, w = sch_without_annotate.get_loops(block_without_annotate)[-2:] + ho, hi = sch_without_annotate.split(h, factors=[10, 10]) + wo, wi = sch_without_annotate.split(w, factors=[10, 10]) + sch_without_annotate.reorder(ho, wo, hi, wi) + sch_without_annotate.compute_at(cache_block_without_annotate, wo) + + assert_structural_equal_ignore_global_symbol( + sch_without_annotate.mod["main"], after_without_annotate_buffer_access + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index 8ae576e9b922..f37c818e7992 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -212,5 +212,33 @@ def test_sample_perfect_tile_after_copy(): sch_copy.sample_perfect_tile(i, n=4) +def test_sample_perfect_tile_on_dynamic_loops(): + """Currently dynamic loop is trivially tiled""" + + @T.prim_func + def workload(a: T.handle) -> None: + n = T.int32() + A = T.match_buffer(a, (n, 1024)) + for i, j in T.grid(n, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 1.0 + + sch = tir.Schedule(workload, debug_mask="all") + di, si = sch.get_loops(sch.get_block("B")) + + factors = sch.sample_perfect_tile(si, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 1024 + + factors = sch.sample_perfect_tile(di, n=4) + assert factors[0] is None + factors = [sch.get(i) for i in factors[1:]] + prod = factors[0] * factors[1] * factors[2] + assert prod == 1 + verify_trace_roundtrip(sch, mod=workload) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index f5e5b3b54e76..22344acfe1d4 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -389,6 +389,41 @@ def test_split_with_inferred_factor(): verify_trace_roundtrip(sch=sch, mod=elementwise) +def test_split_with_dynamic_inferred_factor(): + @T.prim_func + def before(a: T.handle, b: T.handle) -> None: + N = T.int32() + M = T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i, j, k in T.grid(N, 128, M): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle) -> None: + N, M = T.int32(), T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16): + with T.block("B"): + vi = T.axis.spatial(N, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 32 + j_1) + vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1) + T.where(i_0 * 16 + i_1 < N and k_0 * ((M + 15) // 16) + k_1 < M) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0) + + sch = tir.Schedule(before, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 16]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, None]) + assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + def test_split_with_predicate(): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") diff --git a/tests/python/topi/test_topi_depthwise_conv2d_back_input.py b/tests/python/topi/test_topi_depthwise_conv2d_back_input.py index b0a263172010..5087b0047315 100644 --- a/tests/python/topi/test_topi_depthwise_conv2d_back_input.py +++ b/tests/python/topi/test_topi_depthwise_conv2d_back_input.py @@ -36,8 +36,8 @@ def verify_depthwise_conv2d_back_input( stride_w = stride_h padding_w = padding_h - out_h = np.int((in_h + 2 * padding_h - filter_h) / stride_h + 1) - out_w = np.int((in_w + 2 * padding_w - filter_w) / stride_w + 1) + out_h = np.int32((in_h + 2 * padding_h - filter_h) / stride_h + 1) + out_w = np.int32((in_w + 2 * padding_w - filter_w) / stride_w + 1) out_channel = in_channel * channel_multiplier ishape = [batch, in_h, in_w, in_channel] diff --git a/version.py b/version.py index a827571c6cdf..c8151769ba68 100644 --- a/version.py +++ b/version.py @@ -44,7 +44,7 @@ # Two tag formats are supported: # - vMAJ.MIN.PATCH (e.g. v0.8.0) or # - vMAJ.MIN.devN (e.g. v0.8.dev0) -__version__ = "0.18.dev0" +__version__ = "0.19.dev0" # --------------------------------------------------- diff --git a/web/package-lock.json b/web/package-lock.json index 751aaf2ef442..ddc14c7f134d 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.19.0-dev0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.19.0-dev0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/package.json b/web/package.json index a63997bb2f1c..a89b078cd776 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.18.0-dev2", + "version": "0.19.0-dev0", "files": [ "lib" ], From ae4fcc72983151d7fd4f4351bd224f61950e85d3 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Thu, 24 Oct 2024 09:49:03 +0000 Subject: [PATCH 8/8] [Frontend][ArgParse] Compile with default target(LLVM) when built USE_MRVL=ON(#17454) This is a use-case of invoking TVMC with default target though it is built with MRVL_ON. In command line processing, validate_target_args checks if there are add-on options derived from the default arguments of codegen/BYOC and it expects that particular codegen to be given explicitly in command line. However, certain codegen's can have default target alone, in that case codegen optios are not extracted there by relaxing the validation Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 428ec24d4f6f..4cfaf130e4db 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -186,12 +186,6 @@ def validate_targets(parse_targets, additional_target_options=None): if additional_target_options is not None: for target_name in additional_target_options: if not any([target for target in parse_targets if target["name"] == target_name]): - # When built with USE_MRVL=ON, add-on target options are passed from MRVL codegen's - # config which has pass_default=True and compiled with default target, don't error - # Use case: --target="llvm" cnn.onnx - if (len(tvm_targets) == 1) and (target_name == "mrvl"): - return - first_option = list(additional_target_options[target_name].keys())[0] raise TVMCException( f"Passed --target-{target_name}-{first_option}"