From e42431e39acf546715dd25732cdd01f0a7752fa9 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Thu, 2 Nov 2023 09:48:14 +0000 Subject: [PATCH 1/2] finished --- include/tvm/target/tag.h | 11 +++++++++++ src/target/tag.cc | 7 +++++-- src/target/target_kind.cc | 1 + tests/python/unittest/test_target_target.py | 7 +++++++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index 7add206f3ec5..dd798810d201 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -104,6 +104,12 @@ class TargetTagRegEntry { * \param config The config dict for target creation */ inline TargetTagRegEntry& set_config(Map config); + /*! + * \brief Add a key-value pair to the config dict + * \param key The attribute name + * \param value The attribute value + */ + inline TargetTagRegEntry& add_config(String key, ObjectRef value); /*! \brief Set name of the TargetTag to be the same as registry if it is empty */ inline TargetTagRegEntry& set_name(); /*! @@ -131,6 +137,11 @@ inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map c return *this; } +inline TargetTagRegEntry& TargetTagRegEntry::add_config(String key, ObjectRef value) { + tag_->config.Set(key, value); + return *this; +} + inline TargetTagRegEntry& TargetTagRegEntry::set_name() { if (tag_->name.empty()) { tag_->name = name; diff --git a/src/target/tag.cc b/src/target/tag.cc index 28f762609b7d..9c4bce6a7729 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -115,7 +115,7 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") {"max_threads_per_block", Integer(1024)}, \ {"thread_warp_size", Integer(32)}, \ {"registers_per_block", Integer(RegPerBlock)}, \ - }); + }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus // Parameters see Table 15. Technical Specifications per Compute Capability @@ -129,7 +129,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-k20", "sm_35", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); -TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) + .add_config("l2_cache_size_bytes", Integer(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -231,6 +232,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-310", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) + .add_config("l2_cache_size_bytes", Integer(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 56066fcfb6ab..aa4499ec9667 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -334,6 +334,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("max_threads_per_block") .add_attr_option("thread_warp_size", Integer(32)) .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index da1bbc2c211b..d5e8d060254e 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -478,6 +478,13 @@ def test_target_attr_bool_value(): assert target3.attrs["supports_float16"] == 0 +def test_target_attr_l2_cache_size_bytes(): + target0 = Target("nvidia/nvidia-a100") + assert target0.l2_cache_size_bytes == 41943040 + target1 = Target("nvidia/geforce-rtx-4090") + assert target1.l2_cache_size_bytes == 75497472 + + def test_target_features(): target_no_features = Target("cuda") assert target_no_features.features From 0ad186260576cc3c42e69e685e1630234abfdcf0 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Tue, 7 Nov 2023 00:52:56 +0000 Subject: [PATCH 2/2] 1106 --- include/tvm/target/tag.h | 4 ++-- src/target/tag.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index dd798810d201..a2974c89cb77 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -109,7 +109,7 @@ class TargetTagRegEntry { * \param key The attribute name * \param value The attribute value */ - inline TargetTagRegEntry& add_config(String key, ObjectRef value); + inline TargetTagRegEntry& with_config(String key, ObjectRef value); /*! \brief Set name of the TargetTag to be the same as registry if it is empty */ inline TargetTagRegEntry& set_name(); /*! @@ -137,7 +137,7 @@ inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map c return *this; } -inline TargetTagRegEntry& TargetTagRegEntry::add_config(String key, ObjectRef value) { +inline TargetTagRegEntry& TargetTagRegEntry::with_config(String key, ObjectRef value) { tag_->config.Set(key, value); return *this; } diff --git a/src/target/tag.cc b/src/target/tag.cc index 9c4bce6a7729..e6521d384397 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -130,7 +130,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .add_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", Integer(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -233,7 +233,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .add_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", Integer(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536);