diff --git a/.gitmodules b/.gitmodules index a1367c97b2f5..6ef740e33153 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "3rdparty/vta-hw"] path = 3rdparty/vta-hw url = https://github.com/apache/incubator-tvm-vta +[submodule "3rdparty/libbacktrace"] + path = 3rdparty/libbacktrace + url = https://github.com/tlc-pack/libbacktrace.git diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 6c401e242c59..21cc7de0dc9f 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 6c401e242c59a1f4c913918246591bb13fd714e7 +Subproject commit 21cc7de0dc9fd6acb796e1be6181fa8e6b6c8f41 diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace new file mode 160000 index 000000000000..08f7c7e69f8e --- /dev/null +++ b/3rdparty/libbacktrace @@ -0,0 +1 @@ +Subproject commit 08f7c7e69f8ea61a0c4151359bc8023be8e9217b diff --git a/CMakeLists.txt b/CMakeLists.txt index 451b6a7ee2c2..dab99f501665 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,11 @@ tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON) +set(_LIBBACKTRACE_DEFAULT OFF) +if(CMAKE_SYSTEM_NAME MATCHES "Darwin" OR CMAKE_SYSTEM_NAME MATCHES "Linux") + set(_LIBBACKTRACE_DEFAULT ON) +endif() +tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" ${_LIBBACKTRACE_DEFAULT}) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -137,6 +142,8 @@ if(MSVC) add_compile_options(/wd4146) # 'inline': used more than once add_compile_options(/wd4141) + # unknown pragma + add_compile_options(/wd4068) else(MSVC) set(WARNING_FLAG -Wall) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") @@ -388,6 +395,26 @@ set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") add_library(tvm_runtime SHARED $) set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") +target_compile_definitions(tvm_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) +target_compile_definitions(tvm_runtime_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) +target_compile_definitions(tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) +target_compile_definitions(tvm_runtime PUBLIC DMLC_USE_LOGGING_LIBRARY=) +if(USE_LIBBACKTRACE) + message(STATUS "Building with libbacktrace...") + include(cmake/modules/Libbacktrace.cmake) + target_link_libraries(tvm PRIVATE libbacktrace) + target_link_libraries(tvm_runtime PRIVATE libbacktrace) + add_dependencies(tvm_runtime_objs libbacktrace) + # pre 3.12 versions of cmake cannot propagate include directories from imported targets so we set them manually + target_include_directories(tvm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include") + target_include_directories(tvm_objs PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include") + target_include_directories(tvm_runtime PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include") + target_include_directories(tvm_runtime_objs PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include") +else() + target_compile_definitions(tvm_objs PRIVATE TVM_BACKTRACE_DISABLED) + target_compile_definitions(tvm_runtime_objs PRIVATE TVM_BACKTRACE_DISABLED) +endif() + if(USE_MICRO) # NOTE: cmake doesn't track dependencies at the file level across subdirectories. For the # Unix Makefiles generator, need to add these explicit target-level dependency) @@ -402,9 +429,9 @@ endif() if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") target_compile_definitions(tvm_objs PRIVATE "USE_RELAY_DEBUG") - target_compile_definitions(tvm_objs PRIVATE "DMLC_LOG_DEBUG") + target_compile_definitions(tvm_objs PRIVATE "TVM_LOG_DEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_RELAY_DEBUG") - target_compile_definitions(tvm_runtime_objs PRIVATE "DMLC_LOG_DEBUG") + target_compile_definitions(tvm_runtime_objs PRIVATE "TVM_LOG_DEBUG") else() target_compile_definitions(tvm_objs PRIVATE "NDEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "NDEBUG") @@ -475,6 +502,7 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") # once minimum CMake version is bumped up to 3.13 or above. target_link_libraries(tvm PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) target_link_libraries(tvm_runtime PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) + target_compile_definitions(tvm_allvisible PUBLIC DMLC_USE_LOGGING_LIBRARY=) endif() # Tests @@ -543,3 +571,33 @@ if(MSVC) target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS) endif() + +set(TVM_IS_DEBUG_BUILD OFF) +if(CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_CXX_FLAGS MATCHES "-g") + set(TVM_IS_DEBUG_BUILD ON) +endif() + +# Change relative paths in backtrace to absolute ones +if(TVM_IS_DEBUG_BUILD) + set(FILE_PREFIX_MAP_FLAG "-ffile-prefix-map=..=${CMAKE_CURRENT_SOURCE_DIR}") + target_compile_options(tvm PRIVATE "${FILE_PREFIX_MAP_FLAG}") + CHECK_CXX_COMPILER_FLAG("${FILE_PREFIX_MAP_FLAG}" FILE_PREFIX_MAP_SUPPORTED) + if(FILE_PREFIX_MAP_SUPPORTED) + target_compile_options(tvm PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + target_compile_options(tvm_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + target_compile_options(tvm_runtime PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + target_compile_options(tvm_runtime_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + endif() +endif() + +# Run dsymutil to generate debugging symbols for backtraces +if(APPLE AND TVM_IS_DEBUG_BUILD) + find_program(DSYMUTIL dsymutil) + mark_as_advanced(DSYMUTIL) + add_custom_command(TARGET tvm + POST_BUILD + COMMAND ${DSYMUTIL} ARGS $ + COMMENT "Running dsymutil" + VERBATIM + ) +endif() diff --git a/apps/android_camera/app/src/main/jni/Application.mk b/apps/android_camera/app/src/main/jni/Application.mk index 63a79458ef94..5c8774889685 100644 --- a/apps/android_camera/app/src/main/jni/Application.mk +++ b/apps/android_camera/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= all APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif @@ -43,4 +43,4 @@ endif ifeq ($(USE_SORT), 1) APP_CPPFLAGS += -DUSE_SORT=1 -endif \ No newline at end of file +endif diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index 5f3db04274a1..47a3a3de6bba 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++ b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -25,17 +25,13 @@ #include -/* Enable custom logging - this will cause TVM to pass every log message - * through CustomLogMessage instead of LogMessage. By enabling this, we must - * implement dmlc::CustomLogMessage::Log. We use this to pass TVM log - * messages to Android logcat. +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 +/* Enable custom logging - this will cause TVM to use a custom implementation + * of tvm::runtime::detail::LogMessage. We use this to pass TVM log messages to + * Android logcat. */ -#define DMLC_LOG_CUSTOMIZE 1 - -/* Ensure that fatal errors are passed to the logger before throwing - * in LogMessageFatal - */ -#define DMLC_LOG_BEFORE_THROW 1 +#define TVM_LOG_CUSTOMIZE 1 #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" @@ -72,8 +68,20 @@ #include -void dmlc::CustomLogMessage::Log(const std::string& msg) { - // This is called for every message logged by TVM. - // We pass the message to logcat. - __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", msg.c_str()); -} \ No newline at end of file +namespace tvm { +namespace runtime { +namespace detail { +// Override logging mechanism +void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { + std::string m = file + ":" + std::to_string(lineno) + ": " + message; + __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", m.c_str()); + throw InternalError(file, lineno, message); +} +void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { + std::string m = file + ":" + std::to_string(lineno) + ": " + message; + __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", m.c_str()); +} + +} // namespace detail +} // namespace runtime +} // namespace tvm diff --git a/apps/android_deploy/app/src/main/jni/Application.mk b/apps/android_deploy/app/src/main/jni/Application.mk index a50a40bf5cd1..42c4f232a553 100644 --- a/apps/android_deploy/app/src/main/jni/Application.mk +++ b/apps/android_deploy/app/src/main/jni/Application.mk @@ -27,7 +27,7 @@ include $(config) APP_STL := c++_static -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 362d278c38c4..4412e9c62e9d 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -25,6 +25,9 @@ #include +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 + #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/dso_library.cc" diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index 5f885f1c6f14..088eeed750b8 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index fb5993066448..40e6279fb386 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -25,17 +25,13 @@ #include -/* Enable custom logging - this will cause TVM to pass every log message - * through CustomLogMessage instead of LogMessage. By enabling this, we must - * implement dmlc::CustomLogMessage::Log. We use this to pass TVM log - * messages to Android logcat. +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 +/* Enable custom logging - this will cause TVM to use a custom implementation + * of tvm::runtime::detail::LogMessage. We use this to pass TVM log messages to + * Android logcat. */ -#define DMLC_LOG_CUSTOMIZE 1 - -/* Ensure that fatal errors are passed to the logger before throwing - * in LogMessageFatal - */ -#define DMLC_LOG_BEFORE_THROW 1 +#define TVM_LOG_CUSTOMIZE 1 #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" @@ -81,8 +77,20 @@ #include -void dmlc::CustomLogMessage::Log(const std::string& msg) { - // This is called for every message logged by TVM. - // We pass the message to logcat. - __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", msg.c_str()); +namespace tvm { +namespace runtime { +namespace detail { +// Override logging mechanism +void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { + std::string m = file + ":" + std::to_string(lineno) + ": " + message; + __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", m.c_str()); + throw InternalError(file, lineno, message); } +void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { + std::string m = file + ":" + std::to_string(lineno) + ": " + message; + __android_log_write(ANDROID_LOG_DEBUG, "TVM_RUNTIME", m.c_str()); +} + +} // namespace detail +} // namespace runtime +} // namespace tvm diff --git a/apps/bundle_deploy/Makefile b/apps/bundle_deploy/Makefile index 38d9d3456d55..8e23a92afa93 100644 --- a/apps/bundle_deploy/Makefile +++ b/apps/bundle_deploy/Makefile @@ -32,12 +32,14 @@ PKG_CXXFLAGS = ${PKG_COMPILE_OPTS} -std=c++14 \ -I${TVM_ROOT}/include \ -I${DMLC_CORE}/include \ -I${TVM_ROOT}/3rdparty/dlpack/include \ - -Icrt_config + -Icrt_config \ + -DDMLC_USE_LOGGING_LIBRARY=\ PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I${TVM_ROOT}/include \ -I${DMLC_CORE}/include \ -I${TVM_ROOT}/3rdparty/dlpack/include \ - -Icrt_config + -Icrt_config \ + -DDMLC_USE_LOGGING_LIBRARY=\ PKG_LDFLAGS = -pthread -lm diff --git a/apps/dso_plugin_module/Makefile b/apps/dso_plugin_module/Makefile index c2ce3306870a..438d9db223a8 100644 --- a/apps/dso_plugin_module/Makefile +++ b/apps/dso_plugin_module/Makefile @@ -19,7 +19,8 @@ TVM_ROOT=$(shell cd ../..; pwd) PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include + -I${TVM_ROOT}/3rdparty/dlpack/include\ + -DDMLC_USE_LOGGING_LIBRARY=\ PKG_LDFLAGS =-L${TVM_ROOT}/build UNAME_S := $(shell uname -s) diff --git a/apps/extension/Makefile b/apps/extension/Makefile index 91d914aba63b..6eba941f7c98 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -20,7 +20,8 @@ TVM_ROOT=$(shell cd ../..; pwd) PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include + -I${TVM_ROOT}/3rdparty/dlpack/include\ + -DDMLC_USE_LOGGING_LIBRARY=\ PKG_LDFLAGS =-L${TVM_ROOT}/build diff --git a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj index b33c892cf002..28079e710a38 100644 --- a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj +++ b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj @@ -349,6 +349,8 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "DEBUG=1", "$(inherited)", + "DMLC_USE_LOGGING_LIBRARY=", + "TVM_BACKTRACE_DISABLED=1", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; @@ -393,6 +395,10 @@ ENABLE_STRICT_OBJC_MSGSEND = YES; GCC_C_LANGUAGE_STANDARD = gnu99; GCC_NO_COMMON_BLOCKS = YES; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DMLC_USE_LOGGING_LIBRARY=", + "TVM_BACKTRACE_DISABLED=1", + ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.h b/apps/ios_rpc/tvmrpc/TVMRuntime.h index f6a6dc64c53a..0d172fc3eaa1 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.h +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.h @@ -22,7 +22,7 @@ */ #import // Customize logging mechanism, redirect to NSLOG -#define DMLC_LOG_CUSTOMIZE 1 +#define TVM_LOG_CUSTOMIZE 1 #define TVM_METAL_RUNTIME 1 #include diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index fbe4850e1b57..87cb6f9b4c69 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -53,9 +53,19 @@ // CoreML #include "../../../src/runtime/contrib/coreml/coreml_runtime.mm" -namespace dmlc { +namespace tvm { +namespace runtime { +namespace detail { // Override logging mechanism -void CustomLogMessage::Log(const std::string& msg) { NSLog(@"%s", msg.c_str()); } +void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { + throw tvm::runtime::InternalError(file, lineno, message); +} + +void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { + NSLog(@"%s:%d: %s", file.c_str(), lineno, message.c_str()); +} +} +} } // namespace dmlc namespace tvm { @@ -69,7 +79,7 @@ size_t Send(const void* data, size_t size) final { ssize_t nbytes = [stream_ write:reinterpret_cast(data) maxLength:size]; if (nbytes < 0) { NSLog(@"%@", [stream_ streamError].localizedDescription); - throw dmlc::Error("Stream error"); + throw tvm::Error("Stream error"); } return nbytes; } diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 910c650aedc1..879ed2334a84 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -100,7 +100,7 @@ - (void)onReadAvailable { if (flag == 2) { [self onShutdownReceived]; } - } catch (const dmlc::Error& e) { + } catch (const tvm::Error& e) { [self close]; } } @@ -123,7 +123,7 @@ - (void)onWriteAvailable { if (flag == 2) { [self onShutdownReceived]; } - } catch (const dmlc::Error& e) { + } catch (const tvm::Error& e) { [self close]; } } diff --git a/cmake/config.cmake b/cmake/config.cmake index 65859566a664..ed45f20e6b0d 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -272,3 +272,8 @@ set(USE_TARGET_ONNX OFF) # Whether enable BNNS runtime set(USE_BNNS OFF) + +# Whether to use libbacktrace +# Libbacktrace provides line and column information on stack traces from errors. It is only +# supported on linux and macOS. +# set(USE_LIBBACKTRACE OFF) diff --git a/cmake/modules/Libbacktrace.cmake b/cmake/modules/Libbacktrace.cmake new file mode 100644 index 000000000000..742855358809 --- /dev/null +++ b/cmake/modules/Libbacktrace.cmake @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +include(ExternalProject) + +ExternalProject_Add(project_libbacktrace + PREFIX libbacktrace + SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace + BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace + CONFIGURE_COMMAND "${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace/configure" + "--prefix=${CMAKE_CURRENT_BINARY_DIR}/libbacktrace" --with-pic + INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace" + BUILD_COMMAND make + INSTALL_COMMAND make install + BUILD_BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/lib/libbacktrace.a" + "${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include/backtrace.h" + ) + +# Custom step to rebuild libbacktrace if any of the source files change +file(GLOB LIBBACKTRACE_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace/*.c") +ExternalProject_Add_Step(project_libbacktrace checkout + DEPENDERS configure + DEPENDEES download + DEPENDS ${LIBBACKTRACE_SRCS} +) + +add_library(libbacktrace STATIC IMPORTED) +add_dependencies(libbacktrace project_libbacktrace) +set_property(TARGET libbacktrace + PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/lib/libbacktrace.a) +# create include directory so cmake doesn't complain +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 115216680fff..58b58d231d83 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -60,6 +60,7 @@ elseif(PYTHON) # Target lib: vta_fsim add_library(vta_fsim SHARED ${FSIM_RUNTIME_SRCS}) target_include_directories(vta_fsim SYSTEM PUBLIC ${VTA_HW_PATH}/include) + target_compile_definitions(vta_fsim PUBLIC DMLC_USE_LOGGING_LIBRARY=) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) target_compile_definitions(vta_fsim PUBLIC ${__strip_def}) @@ -81,6 +82,7 @@ elseif(PYTHON) # Target lib: vta_tsim add_library(vta_tsim SHARED ${TSIM_RUNTIME_SRCS}) target_include_directories(vta_tsim SYSTEM PUBLIC ${VTA_HW_PATH}/include) + target_compile_definitions(vta_tsim PUBLIC DMLC_USE_LOGGING_LIBRARY=) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) target_compile_definitions(vta_tsim PUBLIC ${__strip_def}) @@ -107,6 +109,7 @@ elseif(PYTHON) add_library(vta SHARED ${FPGA_RUNTIME_SRCS}) target_include_directories(vta PUBLIC vta/runtime) target_include_directories(vta PUBLIC ${VTA_HW_PATH}/include) + target_compile_definitions(vta PUBLIC DMLC_USE_LOGGING_LIBRARY=) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) target_compile_definitions(vta PUBLIC ${__strip_def}) diff --git a/golang/Makefile b/golang/Makefile index 6fd77996e119..137e2a488e29 100644 --- a/golang/Makefile +++ b/golang/Makefile @@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc GOPATH=$(CURDIR)/gopath GOPATHDIR=${GOPATH}/src/${TARGET}/ CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++14" +CGO_CXXFLAGS="-std=c++14 -DDMLC_USE_LOGGING_LIBRARY=\" CGO_CFLAGS="-I${TVM_BASE}" CGO_LDFLAGS="-ldl -lm" diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index f05ab04c3305..da7bc12619bd 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -92,12 +92,12 @@ inline DataType NullValue() { } /*! \brief Error thrown during attribute checking. */ -struct AttrError : public dmlc::Error { +struct AttrError : public Error { /*! * \brief constructor * \param msg error message */ - explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {} + explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {} }; /*! diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 2053a295a3b8..41130a5be0aa 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -37,6 +37,15 @@ namespace tvm { using tvm::parser::SourceMap; using tvm::runtime::TypedPackedFunc; +/*! \brief The diagnostic level, controls the printing of the message. */ +enum class DiagnosticLevel : int { + kBug = 10, + kError = 20, + kWarning = 30, + kNote = 40, + kHelp = 50, +}; + class DiagnosticBuilder; /*! \brief A compiler diagnostic. */ diff --git a/include/tvm/ir/error.h b/include/tvm/ir/error.h index ac7b96a3bd59..6ff61781ac44 100644 --- a/include/tvm/ir/error.h +++ b/include/tvm/ir/error.h @@ -36,11 +36,11 @@ namespace tvm { /*! * \brief A wrapper around std::stringstream to build error. * - * Can be consumed by Error to construct an error. + * Can be consumed by CompileError to construct an error. * * \code * - * void ReportError(const Error& err); + * void ReportError(const CompileError& err); * * void Test(int number) { * // Use error reporter to construct an error. @@ -59,13 +59,13 @@ struct ErrorBuilder { private: std::stringstream stream_; - friend class Error; + friend class CompileError; }; /*! * \brief Custom Error class to be thrown during compilation. */ -class Error : public dmlc::Error { +class CompileError : public Error { public: /*! \brief Location of the error */ Span span; @@ -73,20 +73,20 @@ class Error : public dmlc::Error { * \brief construct error from message. * \param msg The message */ - explicit Error(const std::string& msg) : dmlc::Error(msg), span(nullptr) {} + explicit CompileError(const std::string& msg) : Error(msg), span(nullptr) {} /*! * \brief construct error from error builder. * \param err The error builder */ - Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) + CompileError(const ErrorBuilder& err) : Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) /*! * \brief copy constructor. * \param other The other ereor. */ - Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) + CompileError(const CompileError& other) : Error(other.what()), span(other.span) {} // NOLINT(*) /*! * \brief default constructor. */ - Error() : dmlc::Error(""), span(nullptr) {} + CompileError() : Error(""), span(nullptr) {} }; /*! @@ -115,13 +115,13 @@ class ErrorReporter { ErrorReporter() : errors_(), node_to_error_() {} /*! - * \brief Report a tvm::Error. + * \brief Report a CompileError. * * This API is useful for reporting spanned errors. * * \param err The error to report. */ - void Report(const Error& err) { + void Report(const CompileError& err) { if (!err.span.defined()) { throw err; } @@ -143,7 +143,7 @@ class ErrorReporter { */ void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) { std::string err_msg = err.str(); - this->ReportAt(global, node, Error(err_msg)); + this->ReportAt(global, node, CompileError(err_msg)); } /*! @@ -158,7 +158,7 @@ class ErrorReporter { * \param node The expression or type to report the error at. * \param err The error to report. */ - void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err); + void ReportAt(const GlobalVar& global, const ObjectRef& node, const CompileError& err); /*! * \brief Render all reported errors and exit the program. @@ -176,7 +176,7 @@ class ErrorReporter { inline bool AnyErrors() { return errors_.size() != 0; } private: - std::vector errors_; + std::vector errors_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_to_error_; std::unordered_map node_to_gv_; }; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 462588006c9b..dd6861750a10 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include namespace tvm { diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 5dd837038731..f88b04994099 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 336fef21ab88..362582f4dab9 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -30,6 +30,7 @@ #include #include +#include #include #include diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 7d914ce6bff9..b4fdcbff58b4 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -25,7 +25,7 @@ #define TVM_RUNTIME_DATA_TYPE_H_ #include -#include +#include #include #include diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h new file mode 100644 index 000000000000..952a5ffec637 --- /dev/null +++ b/include/tvm/runtime/logging.h @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/logging.h + * \brief logging utilities + * + * We define our own CHECK and LOG macros to replace those from dmlc-core. + * These macros are then injected into dmlc-core via the + * DMLC_USE_LOGGING_LIBRARY define. dmlc-core will #include this file wherever + * it needs logging. + */ +#ifndef TVM_RUNTIME_LOGGING_H_ +#define TVM_RUNTIME_LOGGING_H_ + +#include + +#include +#include +#include +#include +#include + +#include "tvm/runtime/c_runtime_api.h" + +// a technique that enables overriding macro names on the number of parameters. This is used +// to define other macros below +#define GET_MACRO(_1, _2, _3, _4, _5, NAME, ...) NAME + +/*! + * \brief COND_X calls COND_X_N where N is the number of parameters passed to COND_X + * X can be any of CHECK_GE, CHECK_EQ, CHECK, or LOG COND_X (but not COND_X_N) + * are supposed to be used outside this file. + * The first parameter of COND_X (and therefore, COND_X_N), which we call 'quit_on_assert', + * is a boolean. The rest of the parameters of COND_X is the same as the parameters of X. + * quit_on_assert determines the overall behavior of COND_X. If it's true COND_X + * quits the program on assertion failure. If it's false, then it moves on and somehow reports + * the assertion failure back to the macro caller in an appropriate manner (e.g, 'return false' + * in a function, or 'continue' or 'break' in a loop) + * The default behavior when quit_on_assertion is false, is to 'return false'. If this is not + * desirable, the macro caller can pass one more last parameter to COND_X to tell COND_X what + * to do when when quit_on_assertion is false and the assertion fails. + * + * Rationale: These macros were designed to implement functions that have two behaviors + * in a concise way. Those behaviors are quitting on assertion failures, or trying to + * move on from assertion failures. Note that these macros hide lots of control flow in them, + * and therefore, makes the logic of the whole code slightly harder to understand. However, + * in pieces of code that use these macros frequently, it will significantly shorten the + * amount of code needed to be read, and we won't need to clutter the main logic of the + * function by repetitive control flow structure. The first problem + * mentioned will be improved over time as the developer gets used to the macro. + * + * Here is an example of how to use it + * \code + * bool f(..., bool quit_on_assertion) { + * int a = 0, b = 0; + * ... + * a = ... + * b = ... + * // if quit_on_assertion is true, if a==b, continue, otherwise quit. + * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default + * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" + * ... + * for (int i = 0; i < N; i++) { + * a = ... + * b = ... + * // if quit_on_assertion is true, if a==b, continue, otherwise quit. + * // if quit_on_assertion is false, if a==b, continue, otherwise 'break' (non-default + * // behaviour, therefore, has to be explicitly specified) + * COND_CHECK_EQ(quit_on_assertion, a, b, break) << "some error message when quiting" + * } + * } + * \endcode + */ +#define COND_CHECK_GE(...) \ + GET_MACRO(__VA_ARGS__, COND_CHECK_GE_5, COND_CHECK_GE_4, COND_CHECK_GE_3)(__VA_ARGS__) +#define COND_CHECK_EQ(...) \ + GET_MACRO(__VA_ARGS__, COND_CHECK_EQ_5, COND_CHECK_EQ_4, COND_CHECK_EQ_3)(__VA_ARGS__) +#define COND_CHECK(...) \ + GET_MACRO(__VA_ARGS__, COND_CHECK_5, COND_CHECK_4, COND_CHECK_3, COND_CHECK_2)(__VA_ARGS__) +#define COND_LOG(...) \ + GET_MACRO(__VA_ARGS__, COND_LOG_5, COND_LOG_4, COND_LOG_3, COND_LOG_2)(__VA_ARGS__) + +// Not supposed to be used by users directly. +#define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ + if (!quit_on_assert) { \ + if (!((x)op(y))) what; \ + } else /* NOLINT(*) */ \ + CHECK_##op(x, y) + +#define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) +#define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) + +#define COND_CHECK_3(quit_on_assert, x, what) \ + if (!quit_on_assert) { \ + if (!(x)) what; \ + } else /* NOLINT(*) */ \ + CHECK(x) + +#define COND_LOG_3(quit_on_assert, x, what) \ + if (!quit_on_assert) { \ + what; \ + } else /* NOLINT(*) */ \ + LOG(x) + +#define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) +#define COND_CHECK_GE_3(quit_on_assert, x, y) COND_CHECK_GE_4(quit_on_assert, x, y, return false) +#define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) +#define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) + +#ifdef _MSC_VER +#define TVM_THROW_EXCEPTION noexcept(false) __declspec(noreturn) +#else +#define TVM_THROW_EXCEPTION noexcept(false) +#endif + +namespace tvm { +namespace runtime { + +/* \brief Generate a backtrace when called. + * \return A multiline string of the backtrace. There will be either one or two lines per frame. + */ +std::string Backtrace(); + +/*! \brief Base error type for TVM. Wraps a string message. */ +class Error : public ::dmlc::Error { // for backwards compatibility + public: + /*! \brief Construct an error. + * \param s The message to be displayed with the error. + */ + explicit Error(const std::string& s) : ::dmlc::Error(s) {} +}; + +/*! \brief Error type for errors from CHECK, ICHECK, and LOG(FATAL). This error + * contains a backtrace of where it occured. + */ +class InternalError : public Error { + public: + /*! \brief Construct an error. Not recommended to use directly. Instead use LOG(FATAL). + * + * \param file The file where the error occurred. + * \param lineno The line number where the error occurred. + * \param message The error message to display. + * \param time The time at which the error occurred. This should be in local time. + * \param backtrace Backtrace from when the error occurred. + */ + InternalError(std::string file, int lineno, std::string message, + std::time_t time = std::time(nullptr), std::string backtrace = Backtrace()) + : Error(""), + file_(file), + lineno_(lineno), + message_(message), + time_(time), + backtrace_(backtrace) { + std::ostringstream s; + // XXX: Do not change this format, otherwise all error handling in python will break (because it + // parses the message to reconstruct the error type). + // TODO(tkonolige): Convert errors to Objects, so we can avoid the mess of formatting/parsing + // error messages correctly. + s << "[" << std::put_time(std::localtime(&time), "%H:%M:%S") << "] " << file << ":" << lineno + << ": " << message << std::endl; + if (backtrace.size() > 0) { + s << backtrace << std::endl; + } + full_message_ = s.str(); + } + /*! \return The file in which the error occurred. */ + const std::string& file() const { return file_; } + /*! \return The message associated with this error. */ + const std::string& message() const { return message_; } + /*! \return Formatted error message including file, linenumber, backtrace, and message. */ + const std::string& full_message() const { return full_message_; } + /*! \return The backtrace from where this error occurred. */ + const std::string& backtrace() const { return backtrace_; } + /*! \return The time at which this error occurred. */ + const std::time_t& time() const { return time_; } + /*! \return The line number at which this error occurred. */ + int lineno() const { return lineno_; } + virtual const char* what() const noexcept { return full_message_.c_str(); } + + private: + std::string file_; + int lineno_; + std::string message_; + std::time_t time_; + std::string backtrace_; + std::string full_message_; // holds the full error string +}; + +namespace detail { +#ifndef TVM_LOG_CUSTOMIZE + +/*! \brief Class to accumulate an error message and throw it. Do not use + * directly, instead use LOG(FATAL). + */ +class LogFatal { + public: + LogFatal(const std::string& file, int lineno) : file_(file), lineno_(lineno) {} +#ifdef _MSC_VER +#pragma disagnostic push +#pragma warning(disable : 4722) +#endif + ~LogFatal() noexcept(false) { throw InternalError(file_, lineno_, stream_.str()); } +#ifdef _MSC_VER +#pragma disagnostic pop +#endif + std::ostringstream& stream() { return stream_; } + + private: + std::ostringstream stream_; + std::string file_; + int lineno_; +}; + +/*! \brief Class to accumulate an log message. Do not use directly, instead use + * LOG(INFO), LOG(WARNING), LOG(ERROR). + */ +class LogMessage { + public: + LogMessage(const std::string& file, int lineno) { + std::time_t t = std::time(nullptr); + stream_ << "[" << std::put_time(std::localtime(&t), "%H:%M:%S") << "] " << file << ":" << lineno + << ": "; + } + ~LogMessage() { std::cerr << stream_.str() << std::endl; } + std::ostringstream& stream() { return stream_; } + + private: + std::ostringstream stream_; +}; +#else +// Custom implementations of LogFatal and LogMessage that allow the user to +// override handling of the message. The user must implement LogFatalImpl and LogMessageImpl +void LogFatalImpl(const std::string& file, int lineno, const std::string& message); +class LogFatal { + public: + LogFatal(const std::string& file, int lineno) : file_(file), lineno_(lineno) {} + ~LogFatal() TVM_THROW_EXCEPTION { LogFatalImpl(file_, lineno_, stream_.str()); } + std::ostringstream& stream() { return stream_; } + + private: + std::ostringstream stream_; + std::string file_; + int lineno_; +}; + +void LogMessageImpl(const std::string& file, int lineno, const std::string& message); +class LogMessage { + public: + LogMessage(const std::string& file, int lineno) : file_(file), lineno_(lineno) {} + ~LogMessage() { LogMessageImpl(file_, lineno_, stream_.str()); } + std::ostringstream& stream() { return stream_; } + + private: + std::string file_; + int lineno_; + std::ostringstream stream_; +}; +#endif + +// Below is from dmlc-core +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class LogMessageVoidify { + public: + LogMessageVoidify() {} + // This has to be an operator with a precedence lower than << but + // higher than "?:". See its usage. + void operator&(std::ostream&) {} +}; + +// Also from dmlc-core +inline bool DebugLoggingEnabled() { + static int state = 0; + if (state == 0) { + if (auto var = std::getenv("TVM_LOG_DEBUG")) { + if (std::string(var) == "1") { + state = 1; + } else { + state = -1; + } + } else { + // by default hide debug logging. + state = -1; + } + } + return state == 1; +} + +constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE = + "---------------------------------------------------------------\n" + "An internal invariant was violated during the execution of TVM.\n" + "Please read TVM's error reporting guidelines.\n" + "More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.\n" + "---------------------------------------------------------------\n"; + +// Inline _Pragma in macros does not work reliably on old version of MVSC and +// GCC. We wrap all comparisons in a function so that we can use #pragma to +// silence bad comparison warnings. +#define TVM_CHECK_FUNC(name, op) \ + template \ + DMLC_ALWAYS_INLINE bool LogCheck##name(const A& a, const B& b) { \ + return a op b; \ + } + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +TVM_CHECK_FUNC(_LT, <) +TVM_CHECK_FUNC(_GT, >) +TVM_CHECK_FUNC(_LE, <=) +TVM_CHECK_FUNC(_GE, >=) +TVM_CHECK_FUNC(_EQ, ==) +TVM_CHECK_FUNC(_NE, !=) +#pragma GCC diagnostic pop +} // namespace detail + +#define LOG(level) LOG_##level +#define LOG_FATAL ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() +#define LOG_INFO ::tvm::runtime::detail::LogMessage(__FILE__, __LINE__).stream() +#define LOG_ERROR (::tvm::runtime::detail::LogMessage(__FILE__, __LINE__).stream() << "error: ") +#define LOG_WARNING (::tvm::runtime::detail::LogMessage(__FILE__, __LINE__).stream() << "warning: ") + +#define TVM_CHECK_BINARY_OP(name, op, x, y) \ + if (!::tvm::runtime::detail::LogCheck##name(x, y)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check failed: " << #x " " #op " " #y << ": " + +#define CHECK(x) \ + if (!(x)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check failed: " #x << " == false: " + +#define CHECK_LT(x, y) TVM_CHECK_BINARY_OP(_LT, <, x, y) +#define CHECK_GT(x, y) TVM_CHECK_BINARY_OP(_GT, >, x, y) +#define CHECK_LE(x, y) TVM_CHECK_BINARY_OP(_LE, <=, x, y) +#define CHECK_GE(x, y) TVM_CHECK_BINARY_OP(_GE, >=, x, y) +#define CHECK_EQ(x, y) TVM_CHECK_BINARY_OP(_EQ, ==, x, y) +#define CHECK_NE(x, y) TVM_CHECK_BINARY_OP(_NE, !=, x, y) +#define CHECK_NOTNULL(x) \ + ((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check not null: " #x << ' ', \ + (x) : (x)) // NOLINT(*) + +#define LOG_IF(severity, condition) \ + !(condition) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) + +#if TVM_LOG_DEBUG + +#define LOG_DFATAL LOG_FATAL +#define DFATAL FATAL +#define DLOG(severity) LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled()) +#define DLOG_IF(severity, condition) \ + LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition)) + +#else + +#define LOG_DFATAL LOG_ERROR +#define DFATAL ERROR +#define DLOG(severity) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) +#define DLOG_IF(severity, condition) \ + (true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) + +#endif + +#if TVM_LOG_DEBUG +#define DCHECK(x) \ + while (false) CHECK(x) +#define DCHECK_LT(x, y) \ + while (false) CHECK((x) < (y)) +#define DCHECK_GT(x, y) \ + while (false) CHECK((x) > (y)) +#define DCHECK_LE(x, y) \ + while (false) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) \ + while (false) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) \ + while (false) CHECK((x) == (y)) +#define DCHECK_NE(x, y) \ + while (false) CHECK((x) != (y)) +#else +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#endif + +#define TVM_ICHECK_INDENT " " + +#define ICHECK_BINARY_OP(name, op, x, y) \ + if (!::tvm::runtime::detail::LogCheck##name(x, y)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << std::endl \ + << TVM_ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << ": " + +#define ICHECK(x) \ + if (!(x)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << TVM_ICHECK_INDENT \ + << "Check failed: " #x << " == false: " + +#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y) +#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y) +#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y) +#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y) +#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y) +#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y) +#define ICHECK_NOTNULL(x) \ + ((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE \ + << TVM_ICHECK_INDENT << "Check not null: " #x << ' ', \ + (x) : (x)) // NOLINT(*) + +} // namespace runtime +// Re-export error types +using runtime::Error; +using runtime::InternalError; +} // namespace tvm +#endif // TVM_RUNTIME_LOGGING_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 47788394126e..048fc1d5af54 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -24,7 +24,7 @@ #define TVM_RUNTIME_OBJECT_H_ #include -#include +#include #include #include diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 751a435c734a..7113863a6fb3 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -24,10 +24,10 @@ #ifndef TVM_RUNTIME_PACKED_FUNC_H_ #define TVM_RUNTIME_PACKED_FUNC_H_ -#include #include #include #include +#include #include #include #include @@ -1086,7 +1086,7 @@ struct PackedFuncValueConverter { Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ rv.MoveToCHost(out_value, out_type_code); \ return 0; \ - } catch (const ::std::runtime_error& _except_) { \ + } catch (const ::std::exception& _except_) { \ TVMAPISetLastError(_except_.what()); \ return -1; \ } \ @@ -1140,7 +1140,7 @@ struct PackedFuncValueConverter { f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ rv.MoveToCHost(out_value, out_type_code); \ return 0; \ - } catch (const ::std::runtime_error& _except_) { \ + } catch (const ::std::exception& _except_) { \ TVMAPISetLastError(_except_.what()); \ return -1; \ } \ diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index e858c4458054..72a557fa93b1 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -25,7 +25,7 @@ #define TVM_RUNTIME_VM_BYTECODE_H_ #include -#include +#include #include #include diff --git a/include/tvm/support/logging.h b/include/tvm/support/logging.h deleted file mode 100644 index ced1902a1bd1..000000000000 --- a/include/tvm/support/logging.h +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/support/logging.h - * \brief logging utilities on top of dmlc-core - */ -#ifndef TVM_SUPPORT_LOGGING_H_ -#define TVM_SUPPORT_LOGGING_H_ - -#include - -// a technique that enables overriding macro names on the number of parameters. This is used -// to define other macros below -#define GET_MACRO(_1, _2, _3, _4, _5, NAME, ...) NAME - -/*! - * \brief COND_X calls COND_X_N where N is the number of parameters passed to COND_X - * X can be any of CHECK_GE, CHECK_EQ, CHECK, or LOG (defined dmlc-core/include/dmlc/logging.h.) - * COND_X (but not COND_X_N) are supposed to be used outside this file. - * The first parameter of COND_X (and therefore, COND_X_N), which we call 'quit_on_assert', - * is a boolean. The rest of the parameters of COND_X is the same as the parameters of X. - * quit_on_assert determines the overall behaviour of COND_X. If it's true COND_X - * quits the program on assertion failure. If it's false, then it moves on and somehow reports - * the assertion failure back to the macro caller in an appropriate manner (e.g, 'return false' - * in a function, or 'continue' or 'break' in a loop) - * The default behavior when quit_on_assertion is false, is to 'return false'. If this is not - * desirable, the macro caller can pass one more last parameter to COND_X to tell COND_X what - * to do when when quit_on_assertion is false and the assertion fails. - * - * Rationale: These macros were designed to implement functions that have two behaviours - * in a concise way. Those behaviours are quitting on assertion failures, or trying to - * move on from assertion failures. Note that these macros hide lots of control flow in them, - * and therefore, makes the logic of the whole code slightly harder to understand. However, - * in pieces of code that use these macros frequently, it will significantly shorten the - * amount of code needed to be read, and we won't need to clutter the main logic of the - * function by repetitive control flow structure. The first problem - * mentioned will be improved over time as the developer gets used to the macro. - * - * Here is an example of how to use it - * \code - * bool f(..., bool quit_on_assertion) { - * int a = 0, b = 0; - * ... - * a = ... - * b = ... - * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default - * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" - * ... - * for (int i = 0; i < N; i++) { - * a = ... - * b = ... - * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'break' (non-default - * // behaviour, therefore, has to be explicitly specified) - * COND_CHECK_EQ(quit_on_assertion, a, b, break) << "some error message when quiting" - * } - * } - * \endcode - */ -#define COND_CHECK_GE(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_GE_5, COND_CHECK_GE_4, COND_CHECK_GE_3)(__VA_ARGS__) -#define COND_CHECK_EQ(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_EQ_5, COND_CHECK_EQ_4, COND_CHECK_EQ_3)(__VA_ARGS__) -#define COND_CHECK(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_5, COND_CHECK_4, COND_CHECK_3, COND_CHECK_2)(__VA_ARGS__) -#define COND_LOG(...) \ - GET_MACRO(__VA_ARGS__, COND_LOG_5, COND_LOG_4, COND_LOG_3, COND_LOG_2)(__VA_ARGS__) - -// Not supposed to be used by users directly. -#define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ - if (!quit_on_assert) { \ - if (!((x)op(y))) what; \ - } else /* NOLINT(*) */ \ - CHECK_##op(x, y) - -#define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) -#define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) - -#define COND_CHECK_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - if (!(x)) what; \ - } else /* NOLINT(*) */ \ - CHECK(x) - -#define COND_LOG_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - what; \ - } else /* NOLINT(*) */ \ - LOG(x) - -#define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) -#define COND_CHECK_GE_3(quit_on_assert, x, y) COND_CHECK_GE_4(quit_on_assert, x, y, return false) -#define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) -#define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) - -namespace tvm { - -constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE = - "\n---------------------------------------------------------------\n" - "An internal invariant was violated during the execution of TVM.\n" - "Please read TVM's error reporting guidelines.\n" - "More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.\n" - "---------------------------------------------------------------\n"; - -#define ICHECK_INDENT " " - -#define ICHECK_BINARY_OP(name, op, x, y) \ - if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << tvm::kTVM_INTERNAL_ERROR_MESSAGE << std::endl \ - << ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << *(_check_err.str) << ": " - -#define ICHECK(x) \ - if (!(x)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << tvm::kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT << "Check failed: " #x << " == false: " - -#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y) -#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y) -#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y) -#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y) -#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y) -#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y) -#define ICHECK_NOTNULL(x) \ - ((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << tvm::kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT \ - << "Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) - -/*! \brief The diagnostic level, controls the printing of the message. */ -enum class DiagnosticLevel : int { - kBug = 10, - kError = 20, - kWarning = 30, - kNote = 40, - kHelp = 50, -}; - -} // namespace tvm -#endif // TVM_SUPPORT_LOGGING_H_ diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 90c82c4f3a06..d4547a304e8f 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -25,7 +25,7 @@ #ifndef TVM_SUPPORT_WITH_H_ #define TVM_SUPPORT_WITH_H_ -#include +#include #include diff --git a/licenses/LICENSE.libbacktrace.txt b/licenses/LICENSE.libbacktrace.txt new file mode 100644 index 000000000000..097d2774e5df --- /dev/null +++ b/licenses/LICENSE.libbacktrace.txt @@ -0,0 +1,29 @@ +# Copyright (C) 2012-2016 Free Software Foundation, Inc. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# (1) Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# (2) Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. + +# (3) The name of the author may not be used to +# endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, +# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. diff --git a/python/setup.py b/python/setup.py index e02369e97777..b47e5b14f6a7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,7 +94,7 @@ def config_cython(): subdir = "_cy2" ret = [] path = "tvm/_ffi/_cython" - extra_compile_args = ["-std=c++14"] + extra_compile_args = ["-std=c++14", "-DDMLC_USE_LOGGING_LIBRARY="] if os.name == "nt": library_dirs = ["tvm", "../build/Release", "../build"] libraries = ["tvm"] diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 397090618ade..0496195fd73f 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -253,7 +253,9 @@ def c2pyerror(err_msg): message = [] for line in arr: if trace_mode: - if line.startswith(" "): + if line.startswith(" "): + stack_trace[-1] += "\n" + line + elif line.startswith(" "): stack_trace.append(line) else: trace_mode = False diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 3837d423f8bd..d95f14f0349e 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -118,7 +118,7 @@ def get_runtime_libs() -> str: RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE) -_COMMON_CFLAGS = ["-Wall", "-Werror"] +_COMMON_CFLAGS = ["-Wall", "-Werror", "-DDMLC_USE_LOGGING_LIBRARY="] def _build_default_compiler_options(standalone_crt_dir: typing.Optional[str] = None) -> str: diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 4e7fb05660a4..abbcba234848 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1367,7 +1367,7 @@ Array ComputeDAG::InferBound(const Array& states) const { support::parallel_for(0, states.size(), [this, &states, &out_states](int i) { try { out_states.Set(i, (states[i].defined()) ? this->InferBound(states[i]) : states[i]); - } catch (dmlc::Error& e) { + } catch (Error& e) { LOG(WARNING) << "InferBound fails on the state:\n" << states[i] << "\n" << "with: " << e.what() << std::endl; diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index d93218c0208c..b3c62f01c7c8 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1328,7 +1328,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& prim_func = (*it).second.as(); GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs, feature); - } catch (dmlc::Error& e) { + } catch (Error& e) { (*error_ct)++; } } diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 110be6bd6f68..8eaf80321456 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -1106,7 +1106,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo } try { StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); - } catch (dmlc::Error& e) { + } catch (Error& e) { return ResultKind::kInvalid; } } @@ -1228,7 +1228,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol tmp_s.CopyOnWrite()->transform_steps.push_back(step); try { StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); - } catch (dmlc::Error& e) { + } catch (Error& e) { return ResultKind::kInvalid; } } diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc old mode 100755 new mode 100644 index 5ba3eee07098..b67d5cdd7bd9 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -26,8 +26,8 @@ #include #include #include +#include #include -#include #include #include diff --git a/src/ir/error.cc b/src/ir/error.cc index 5d3978dda4ff..0089f55a4da8 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -132,7 +132,8 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { LOG(FATAL) << annotated_prog.str() << std::endl; } -void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err) { +void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, + const CompileError& err) { size_t index_to_insert = this->errors_.size(); this->errors_.push_back(err); auto it = this->node_to_error_.find(node); diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 3061735eff7c..c7d8e025848a 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -28,9 +28,9 @@ #include #include #include +#include #include #include -#include #include @@ -172,8 +172,8 @@ class ScopeStack { void PopStack() { this->scope_stack.pop_back(); } }; -struct DuplicateKeyError : public dmlc::Error { - explicit DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {} +struct DuplicateKeyError : public Error { + explicit DuplicateKeyError(const std::string& msg) : Error(msg) {} }; /*! \brief A table of interning strings as global function and type names. */ @@ -1492,7 +1492,7 @@ class Parser { DLOG(INFO) << "op_name=" << op_name << " span=" << span; try { return Op::Get(op_name); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { // we can relax this, but probably need to relax checks or return non-null here. this->diag_ctx.EmitFatal(Diagnostic::Error(span) << "operator `" << op_name diff --git a/src/parser/span_check.h b/src/parser/span_check.h index 9a887474fe67..ab71d30a54f5 100644 --- a/src/parser/span_check.h +++ b/src/parser/span_check.h @@ -30,8 +30,8 @@ #include #include #include +#include #include -#include #include #include diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 04a18c4b7351..85a9c51a2fa8 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -157,8 +157,9 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { // Check if the argument already belongs to a region auto region = region_set_->GetRegion(call->args[0]); if (!region.defined()) { - throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n" - << AsText(GetRef(call), false)); + throw CompileError(ErrorBuilder() + << "Cannot find the corresponding region for end annotation:\n" + << AsText(GetRef(call), false)); } else { // If the argument is belonged to a region, it must have the same target. // Otherwise we should see a region_begin op. diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index c7c5a0a9f083..65b8516cb16c 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor { << "Expected " << data->type_vars.size() << "arguments for " << tc << "; got " << op->args.size()); } - } catch (const dmlc::Error& err) { + } catch (const Error& err) { // TODO(@jroesch): can probably relax to just emit EmitFatal(Diagnostic::Error(op->span) << "the type variable : `" << var->name_hint << "` is undefined"); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index cc1ada677c65..22e2e9a71040 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -617,10 +617,10 @@ bool TypeSolver::Solve() { } rnode->resolved = resolved; - } catch (const Error& err) { + } catch (const CompileError& err) { this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; - } catch (const dmlc::Error& e) { + } catch (const Error& e) { ICHECK(false) << e.what(); } diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 856c5dc7aac1..acc1a9adc9f4 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 251a55f10b72..9d3ffc558aae 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -33,8 +33,8 @@ #include #include #include +#include #include -#include #include #include diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 615a8181b387..9c813a4f561c 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -29,8 +29,8 @@ #include #include #include +#include #include -#include #include #include diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index eb848eb7a828..05fb2a120620 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index cc530a10188e..c768a2c300ec 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index cdf898fca756..5e9b1b7978f9 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 2a49a2e251f8..379fa3fa71d3 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -25,7 +25,7 @@ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ #include -#include +#include #include #include diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e3929bf8b77e..b65068bd0506 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -312,7 +312,7 @@ bool StackRel(const Array& types, int num_inputs, const Attrs& attrs, if (first->shape[j].as() || e->shape[j].as() || reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error( + throw CompileError( "relay.stack requires all tensors have the same shape " "on non-stacking axes"); } @@ -483,7 +483,7 @@ Array> TransposeInferCorrectLayout(const Attrs& attrs, } try { return Array>({{Layout(in_layout_str)}, {Layout(out_layout_str)}}); - } catch (const dmlc::Error& e) { + } catch (const tvm::Error& e) { // If the layout string is invalid for any reason, give up. return Array>({{Layout::Undef()}, {Layout::Undef()}}); } @@ -1691,8 +1691,8 @@ bool MeshgridRel(const Array& types, int num_inputs, const Attrs& raw_attr const MeshgridAttrs* attrs = raw_attrs.as(); const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { - throw Error( - ErrorBuilder() << "meshgrid requires a tuple of tensors as the first argument, found " + throw CompileError(ErrorBuilder() + << "meshgrid requires a tuple of tensors as the first argument, found " << PrettyPrint(types[0])); } else if (types[0].as() != nullptr) { return false; @@ -1714,14 +1714,14 @@ bool MeshgridRel(const Array& types, int num_inputs, const Attrs& raw_attr int e_ndim = static_cast(e->shape.size()); const DataType& e_dtype = e->dtype; if (e_dtype != dtype) { - throw Error("relay.meshgrid requires all tensors have the same dtype"); + throw CompileError("relay.meshgrid requires all tensors have the same dtype"); } if (e_ndim == 0) { grid_shape.emplace_back(1); } else if (e_ndim == 1) { grid_shape.emplace_back(e->shape[0]); } else { - throw Error("relay.meshgrid requires all tensors be either scalars or 1-D vectors."); + throw CompileError("relay.meshgrid requires all tensors be either scalars or 1-D vectors."); } } diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index dbf8537e0dad..3c670bcaaa51 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -78,8 +78,8 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs // Sanity check: axis int axis = param->axis; if (!(-ndim <= axis && axis < ndim)) { - throw Error(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis << ", and ndim = " << ndim); + throw CompileError(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim); } axis = axis < 0 ? ndim + axis : axis; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 7b30aea2eb57..6e30ad9624c4 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -85,7 +85,7 @@ TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataTyp } else if (EqualCheck(s1, s2)) { oshape.push_back(s1); } else { - throw Error(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); + throw CompileError(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); } } diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 59a519d66436..eb0f83836a54 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -51,9 +51,10 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at if (types[1].as()) { return false; } else { - throw Error(ErrorBuilder() - << "qnn concatenate requires a tuple of scales as the second argument, found " - << PrettyPrint(types[1])); + throw CompileError( + ErrorBuilder() + << "qnn concatenate requires a tuple of scales as the second argument, found " + << PrettyPrint(types[1])); } } for (const auto& input_scale : input_scales_tuple->fields) { @@ -68,9 +69,10 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at if (types[2].as()) { return false; } else { - throw Error(ErrorBuilder() - << "qnn concatenate requires a tuple of zero_points as the third argument, found " - << PrettyPrint(types[2])); + throw CompileError( + ErrorBuilder() + << "qnn concatenate requires a tuple of zero_points as the third argument, found " + << PrettyPrint(types[2])); } } for (const auto& input_zero_point : input_zero_points_tuple->fields) { diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc index bab8b814df05..d959e5b75e40 100644 --- a/src/relay/transforms/fold_explicit_padding.cc +++ b/src/relay/transforms/fold_explicit_padding.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "../op/tensor/transform.h" #include "pattern_utils.h" diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index dae34674de77..6e6505b28dc6 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -36,7 +36,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index b8c87909a025..f75b7ba1fc75 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index fa080a7ff22c..3a87aa8ed498 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -861,8 +861,8 @@ class PartialEvaluator : public ExprFunctor return VisitFunc(GetRef(op), ll); } - struct ReflectError : dmlc::Error { - ReflectError() : dmlc::Error("static value not found") {} + struct ReflectError : Error { + ReflectError() : Error("static value not found") {} }; Expr Reflect(const PStatic& st) { diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 3c8876ceccb5..b4f4cc16e9df 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "../op/tensor/transform.h" #include "pattern_utils.h" diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 05844477cc5b..91e8d90c1232 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "../../support/arena.h" #include "../analysis/dependency_graph.h" diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 1aab367cf22a..79157bba1918 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "../../support/arena.h" #include "../analysis/dependency_graph.h" diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index b4ccd1659865..4c6013792426 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -166,7 +166,7 @@ class TypeInferencer : private ExprFunctor, bool assign_rhs = true) { try { return solver_.Unify(t1, t2, span, assign_lhs, assign_rhs); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { this->EmitFatal(Diagnostic::Error(span) << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); return Type(); diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 7fd27cba6136..150d7f215da5 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -384,7 +384,7 @@ typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } -int TVMAPIHandleException(const std::runtime_error& e) { +int TVMAPIHandleException(const std::exception& e) { TVMAPISetLastError(NormalizeError(e.what()).c_str()); return -1; } @@ -518,7 +518,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, resource_handle); if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); } else { @@ -529,7 +529,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, rpack.get()); if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); } diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 16496e06aae3..fbac6222488d 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -21,8 +21,8 @@ * \file Use external cblas library call. */ #include +#include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index 273aa45367dd..4323878db276 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -21,8 +21,8 @@ * \file Use external mkl library call. */ #include +#include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/cblas/mkldnn.cc b/src/runtime/contrib/cblas/mkldnn.cc index 1c3fa023dcc7..31abd317c6a4 100644 --- a/src/runtime/contrib/cblas/mkldnn.cc +++ b/src/runtime/contrib/cblas/mkldnn.cc @@ -21,8 +21,8 @@ * \file Use external cblas library call. */ #include +#include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index b12992f57159..9af1602cf3c0 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -21,8 +21,8 @@ * \file Use external cblas library call. */ #include +#include #include -#include #include "../cblas/gemm_common.h" #include "cublas_utils.h" diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 32c3b03ddbb0..3edb8300be88 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #if CUDART_VERSION >= 10010 diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 528298b75187..9b8e9fb33f98 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include "../../cuda/cuda_common.h" diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index 9982f0914f6b..e5a769a974f0 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index d1c49732318a..c2b7e3c7aa99 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -28,8 +28,8 @@ #include #include #include +#include #include -#include #include diff --git a/src/runtime/contrib/nnpack/convolution.cc b/src/runtime/contrib/nnpack/convolution.cc index b3ea6c891d43..0d6359495902 100644 --- a/src/runtime/contrib/nnpack/convolution.cc +++ b/src/runtime/contrib/nnpack/convolution.cc @@ -23,8 +23,8 @@ #include #include #include +#include #include -#include #include "nnpack_utils.h" diff --git a/src/runtime/contrib/nnpack/fully_connected.cc b/src/runtime/contrib/nnpack/fully_connected.cc index 8b72eb38e08c..28570026ada3 100644 --- a/src/runtime/contrib/nnpack/fully_connected.cc +++ b/src/runtime/contrib/nnpack/fully_connected.cc @@ -22,8 +22,8 @@ */ #include #include +#include #include -#include #include "nnpack_utils.h" diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 231309baaa8e..4396ea0bcde6 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -25,8 +25,8 @@ #include #include #include +#include #include -#include namespace tvm { namespace contrib { diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 49bc056dcafb..699f6bbcf376 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -22,8 +22,8 @@ * \brief mt19937 random engine */ #include +#include #include -#include #include #include diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index edcd20883369..2d111bc322ab 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -22,8 +22,8 @@ */ #include #include +#include #include -#include #include diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index dca1ebc6ed83..d977b1a211b0 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -23,8 +23,8 @@ #include "rocblas.h" #include +#include #include -#include namespace tvm { namespace contrib { diff --git a/src/runtime/contrib/tensorrt/tensorrt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h index 087cb010189c..eb0164210dbb 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_logger.h +++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ #define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ -#include +#include #include "NvInfer.h" #include "tensorrt_utils.h" diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc index 37dc767d31af..0e5e2ce4c4fa 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc @@ -25,6 +25,7 @@ #include +#include #include #include #include diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index b745be33b456..133bb01d7d13 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -22,8 +22,8 @@ */ #include #include +#include #include -#include #include #include diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index 0f3e3096e319..d707d0c63b81 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -45,8 +45,8 @@ QUIET ?= @ CRT_PREFIX = $(wildcard src/crt) INCLUDES ?= -isystem include -iquote $(dir ${CRT_CONFIG}) -CFLAGS += ${INCLUDES} -Werror -g $(EXTRA_CFLAGS) -CXXFLAGS += ${INCLUDES} -std=c++11 -Werror -g $(EXTRA_CXXFLAGS) +CFLAGS += ${INCLUDES} -Werror -g $(EXTRA_CFLAGS) -DDMLC_USE_LOGGING_LIBRARY=\ +CXXFLAGS += ${INCLUDES} -std=c++11 -Werror -g $(EXTRA_CXXFLAGS) -DDMLC_USE_LOGGING_LIBRARY=\ LDFLAGS += -Werror -g $(EXTRA_LDFLAGS) ${BUILD_DIR}/%.o: src/%.c $(CRT_CONFIG) diff --git a/src/runtime/crt/graph_runtime/load_json.c b/src/runtime/crt/graph_runtime/load_json.c index 6de49a3f9789..3d1fb601a355 100644 --- a/src/runtime/crt/graph_runtime/load_json.c +++ b/src/runtime/crt/graph_runtime/load_json.c @@ -173,7 +173,7 @@ char JSONReader_PeekNextNonSpace(JSONReader* reader) { * \param out_str the output string. NULL to merely consume input and discard it. * \param out_str_size Number of bytes available to write starting from out_str. Includes * terminating \0. - * \throw dmlc::Error when next token is not string + * \throw tvm::Error when next token is not string */ int JSONReader_ReadString(JSONReader* reader, char* out_str, size_t out_str_size) { int status = 0; diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 92c398b559d2..32dd1d8020c9 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -24,9 +24,9 @@ #include #include +#include #include #include -#include #include #include diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 6c51e711aef1..7e98acb6fb3e 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -491,7 +491,7 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, } else if (name == "share_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { const auto& module = args[0].operator Module(); - ICHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); + ICHECK_EQ(module.operator->()->type_key(), std::string("GraphRuntime")); const auto& param_blob = args[1].operator std::string(); dmlc::MemoryStringStream strm(const_cast(¶m_blob)); this->ShareParams(dynamic_cast(*module.operator->()), &strm); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 70cebf5afa44..a01c9def5d5d 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 994e24b99084..f6a57ff55355 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -22,8 +22,8 @@ #ifdef __ANDROID__ #include #endif +#include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index e558997b7a4c..02ed7d2541c2 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -20,8 +20,8 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ +#include #include -#include #include #include diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/sim/hexagon_device_sim.cc index 6cc7dcf3209f..1d3f0fd1006f 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/sim/hexagon_device_sim.cc @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/hexagon/target/hexagon_dsprpcapi.cc b/src/runtime/hexagon/target/hexagon_dsprpcapi.cc index d494db82e2c7..a089684c4188 100644 --- a/src/runtime/hexagon/target/hexagon_dsprpcapi.cc +++ b/src/runtime/hexagon/target/hexagon_dsprpcapi.cc @@ -22,7 +22,7 @@ #include #include -#include +#include #include "hexagon_target_log.h" diff --git a/src/runtime/hexagon/target/hexagon_dsprpcapi.h b/src/runtime/hexagon/target/hexagon_dsprpcapi.h index c0e40805ecbf..e4711e3da584 100644 --- a/src/runtime/hexagon/target/hexagon_dsprpcapi.h +++ b/src/runtime/hexagon/target/hexagon_dsprpcapi.h @@ -22,7 +22,7 @@ #ifdef __ANDROID__ #include -#include +#include #include "remote.h" #include "remote64.h" diff --git a/src/runtime/hexagon/target/hexagon_stubapi.cc b/src/runtime/hexagon/target/hexagon_stubapi.cc index 5428ae7c1cff..1fb7d942e968 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.cc +++ b/src/runtime/hexagon/target/hexagon_stubapi.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include "hexagon_target_log.h" diff --git a/src/runtime/hexagon/target/hexagon_stubapi.h b/src/runtime/hexagon/target/hexagon_stubapi.h index cc5b7b7413ca..fba22b10247c 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.h +++ b/src/runtime/hexagon/target/hexagon_stubapi.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc new file mode 100644 index 000000000000..8a44ec04532c --- /dev/null +++ b/src/runtime/logging.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_BACKTRACE_DISABLED +#include + +// TODO(bkimball,tkonolige) This inline function is to work around a linking error I am having when +// using MSVC If the function definition is in logging.cc then the linker can't find it no matter +// what kind of attributes (dllexport) I decorate it with. This is temporary and will be addressed +// when we get backtrace working on Windows. +namespace tvm { +namespace runtime { +__declspec(dllexport) std::string Backtrace() { return ""; } +} // namespace runtime +} // namespace tvm +#else + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace { + +struct BacktraceInfo { + std::vector lines; + size_t max_size; + std::string error_message; +}; + +void BacktraceCreateErrorCallback(void* data, const char* msg, int errnum) { + std::cerr << "Could not initialize backtrace state: " << msg << std::endl; +} + +backtrace_state* BacktraceCreate() { + return backtrace_create_state(nullptr, 1, BacktraceCreateErrorCallback, nullptr); +} + +static backtrace_state* _bt_state = BacktraceCreate(); + +std::string DemangleName(std::string name) { + int status = 0; + size_t length = name.size(); + std::unique_ptr demangled_name = { + abi::__cxa_demangle(name.c_str(), nullptr, &length, &status), &std::free}; + if (demangled_name && status == 0 && length > 0) { + return demangled_name.get(); + } else { + return name; + } +} + +void BacktraceErrorCallback(void* data, const char* msg, int errnum) { + // do nothing +} + +void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uintptr_t symval, + uintptr_t symsize) { + auto str = reinterpret_cast(data); + + if (symname != nullptr) { + std::string tmp(symname, symsize); + *str = DemangleName(tmp.c_str()); + } else { + std::ostringstream s; + s << "0x" << std::setfill('0') << std::setw(sizeof(uintptr_t) * 2) << std::hex << pc; + *str = s.str(); + } +} + +int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno, + const char* symbol) { + auto stack_trace = reinterpret_cast(data); + std::stringstream s; + + std::unique_ptr symbol_str = std::make_unique(""); + if (symbol != nullptr) { + *symbol_str = DemangleName(symbol); + } else { + // see if syminfo gives anything + backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback, + symbol_str.get()); + } + s << *symbol_str; + + if (filename != nullptr) { + s << std::endl << " at " << filename; + if (lineno != 0) { + s << ":" << lineno; + } + } + // Skip tvm::backtrace and tvm::LogFatal::~LogFatal at the beginning of the trace as they don't + // add anything useful to the backtrace. + if (!(stack_trace->lines.size() == 0 && + (symbol_str->find("tvm::runtime::Backtrace", 0) == 0 || + symbol_str->find("tvm::runtime::detail::LogFatal", 0) == 0))) { + stack_trace->lines.push_back(s.str()); + } + // TVMFuncCall denotes the API boundary so we stop there. Exceptions should be caught there. + if (*symbol_str == "TVMFuncCall" || stack_trace->lines.size() >= stack_trace->max_size) { + return 1; + } + return 0; +} +} // namespace + +std::string Backtrace() { + BacktraceInfo bt; + bt.max_size = 100; + if (_bt_state == nullptr) { + return ""; + } + // libbacktrace eats memory if run on multiple threads at the same time, so we guard against it + static std::mutex m; + std::lock_guard lock(m); + backtrace_full(_bt_state, 0, BacktraceFullCallback, BacktraceErrorCallback, &bt); + + std::ostringstream s; + s << "Stack trace:\n"; + for (size_t i = 0; i < bt.lines.size(); i++) { + s << " " << i << ": " << bt.lines[i] << "\n"; + } + + return s.str(); +} +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index bd07dbfde9d0..b5d06192396b 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -32,8 +32,8 @@ #import #include #include +#include #include -#include #include #include diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 6c0d0c4c40fe..cd916d46971d 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -25,8 +25,8 @@ #include #include +#include #include -#include #include #include diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index d5c61eccfd6d..3b9772f2fb60 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -46,7 +46,7 @@ #endif #if TVM_MINRPC_ENABLE_LOGGING -#include +#include #endif namespace tvm { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index d3ddbf8c0229..d46f0868a2ea 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -23,9 +23,9 @@ */ #include #include +#include #include #include -#include #include "runtime_base.h" diff --git a/src/runtime/object.cc b/src/runtime/object.cc index ad68c70698ea..c9a9669671e6 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -20,9 +20,9 @@ * \file src/runtime/object.cc * \brief Object type management system. */ +#include #include #include -#include #include #include diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 2e7f05f91020..3fca368c758b 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -26,8 +26,8 @@ #include #include +#include #include -#include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, * hence we use 1.2 APIs, some of which are now deprecated. In order diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index a65235090bfd..bb5a794a030b 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -22,8 +22,8 @@ * \brief The global registry of packed function. */ #include +#include #include -#include #include #include diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 5f24ce0eec48..5d03374a4571 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include #include "rocm_common.h" diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 06737f99a4de..cdeeb368f5a2 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -21,8 +21,8 @@ * \file rpc_device_api.cc */ #include +#include #include -#include #include @@ -72,7 +72,7 @@ class RPCDeviceAPI final : public DeviceAPI { auto remote_ctx = RemoveRPCSessionMask(ctx); try { GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { // fault tolerance to remote close. } delete space; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 8716355fd68f..5e2bba88921e 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -526,7 +526,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { try { fconstructor->CallPacked(constructor_args, &con_ret); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { LOG(FATAL) << "Server[" << name_ << "]:" << " Error caught from session constructor " << constructor_name << ":\n" << e.what(); @@ -540,7 +540,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { ICHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->ReturnException(e.what()); } @@ -562,7 +562,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } this->SwitchToState(kRecvPacketNumBytes); }); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->ReturnException(e.what()); this->SwitchToState(kRecvPacketNumBytes); } @@ -581,7 +581,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { setter(0, rv); this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1)); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->ReturnException(e.what()); } this->SwitchToState(kRecvPacketNumBytes); @@ -719,7 +719,7 @@ void RPCEndpoint::Shutdown() { writer_.bytes_available()); if (n == 0) break; } - } catch (const dmlc::Error& e) { + } catch (const Error& e) { } channel_.reset(nullptr); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 34691415c1a4..46e1be794520 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -130,7 +130,7 @@ class RPCWrappedFunc : public Object { ~RPCWrappedFunc() { try { sess_->FreeHandle(handle_, kTVMPackedFuncHandle); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { // fault tolerance to remote close } } @@ -165,7 +165,7 @@ class RPCModuleNode final : public ModuleNode { if (module_handle_ != nullptr) { try { sess_->FreeHandle(module_handle_, kTVMModuleHandle); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { // fault tolerance to remote close } module_handle_ = nullptr; diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 0ac5b8dc74ef..2b75018099d5 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -46,7 +46,7 @@ void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values try { this->CallFunc(func, arg_values, arg_type_codes, num_args, [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->SendException(callback, e.what()); } } @@ -60,7 +60,7 @@ void RPCSession::AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, try { this->CopyToRemote(local_from_bytes, remote_to, nbytes); callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->SendException(callback, e.what()); } } @@ -74,7 +74,7 @@ void RPCSession::AsyncCopyFromRemote(DLTensor* remote_from, void* local_to_bytes try { this->CopyFromRemote(remote_from, local_to_bytes, nbytes); callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->SendException(callback, e.what()); } } @@ -88,7 +88,7 @@ void RPCSession::AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, try { this->GetDeviceAPI(ctx)->StreamSync(ctx, stream); callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); - } catch (const std::runtime_error& e) { + } catch (const std::exception& e) { this->SendException(callback, e.what()); } } diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index 21601df1ad39..7abb32935a2b 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -34,7 +34,7 @@ and finishes with API_END() or API_END_HANDLE_ERROR */ #define API_END() \ } \ - catch (std::runtime_error & _except_) { \ + catch (std::exception & _except_) { \ return TVMAPIHandleException(_except_); \ } \ return 0; // NOLINT(*) @@ -45,7 +45,7 @@ */ #define API_END_HANDLE_ERROR(Finalize) \ } \ - catch (std::runtime_error & _except_) { \ + catch (std::exception & _except_) { \ Finalize; \ return TVMAPIHandleException(_except_); \ } \ @@ -56,6 +56,6 @@ * \param e the exception * \return the return value of API after exception is handled */ -int TVMAPIHandleException(const std::runtime_error& e); +int TVMAPIHandleException(const std::exception& e); #endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 5f5a811c2d30..cab04ec0db4a 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -24,10 +24,10 @@ #include #include #include +#include #include #include #include -#include #if TVM_THREADPOOL_USE_OPENMP #include #endif diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 2527f4799086..7f9cfaa8730c 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -21,8 +21,8 @@ * \file threading_backend.cc * \brief Native threading backend */ +#include #include -#include #include #include diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index f82d708468f7..09b928fa1e39 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -22,8 +22,8 @@ * \brief The bytecode for Relay virtual machine. */ +#include #include -#include #include diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 6d121aa67733..4683398b01d4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -24,10 +24,10 @@ #include #include +#include #include #include #include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 9cd1f257f091..3083ba6f9ce4 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -22,8 +22,8 @@ #include #include +#include #include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/vulkan_shader.h index c9fbb13e938d..513e3bccc36e 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/vulkan/vulkan_shader.h @@ -22,8 +22,8 @@ #include #include +#include #include -#include #include diff --git a/src/support/base64.h b/src/support/base64.h index 901922db8edc..3aac9920a075 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -26,7 +26,7 @@ #ifndef TVM_SUPPORT_BASE64_H_ #define TVM_SUPPORT_BASE64_H_ -#include +#include #include #include diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc index f4756c29adeb..4ced0df6ddf3 100644 --- a/src/support/parallel_for.cc +++ b/src/support/parallel_for.cc @@ -21,7 +21,7 @@ * \file parallel_for.cc * \brief An implementation to run loop in parallel. */ -#include +#include #include #include diff --git a/src/support/pipe.h b/src/support/pipe.h index 3c1356ba174c..a2803638e1f3 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -25,7 +25,7 @@ #define TVM_SUPPORT_PIPE_H_ #include -#include +#include #ifdef _WIN32 #include diff --git a/src/support/socket.h b/src/support/socket.h index 16fba6b58e3d..11060ae8aae1 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -49,7 +49,7 @@ using ssize_t = int; #include #include #endif -#include +#include #include #include diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 35bfc8dc2e5b..61dd7024ff05 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -24,7 +24,7 @@ #include "llvm_common.h" -#include +#include #include #include diff --git a/src/target/target.cc b/src/target/target.cc index b5ca4c38bbb9..55ef5f1a4e24 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -79,7 +79,7 @@ static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expecte std::ostringstream os; os << ": Expects type \"" << expected_type << "\", but gets \"" << obj->GetTypeKey() << "\" for object: " << obj; - throw dmlc::Error(os.str()); + throw Error(os.str()); } return ptr; } @@ -87,7 +87,7 @@ static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expecte static TargetKind GetTargetKind(const String& name) { Optional kind = TargetKind::Get(name); if (!kind.defined()) { - throw dmlc::Error(": Target kind \"" + name + "\" is not defined"); + throw Error(": Target kind \"" + name + "\" is not defined"); } return kind.value(); } @@ -98,10 +98,10 @@ static std::string RemovePrefixDashes(const std::string& s) { for (; n_dashes < len && s[n_dashes] == '-'; ++n_dashes) { } if (n_dashes == 0) { - throw dmlc::Error(": Attribute keys should start with '-', not an attribute key: " + s); + throw Error(": Attribute keys should start with '-', not an attribute key: " + s); } if (n_dashes >= len) { - throw dmlc::Error(": Not an attribute key: " + s); + throw Error(": Not an attribute key: " + s); } return s.substr(n_dashes); } @@ -133,7 +133,7 @@ static int ParseKVPair(const std::string& s, const std::string& s_next, std::str result_k = s.substr(0, pos); result_v = s.substr(pos + 1); if (result_k.empty() || result_v.empty()) { - throw dmlc::Error(": Empty attribute key or value in \"" + s + "\""); + throw Error(": Empty attribute key or value in \"" + s + "\""); } return 1; } else if (!s_next.empty() && s_next[0] != '-') { @@ -163,7 +163,7 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi } os << kv.first; } - throw dmlc::Error(os.str()); + throw Error(os.str()); } return it->second; } @@ -177,14 +177,14 @@ ObjectRef TargetInternal::ParseType(const std::string& str, // Parsing integer int v; if (!(is >> v)) { - throw dmlc::Error(": Cannot parse into type \"Integer\" from string: " + str); + throw Error(": Cannot parse into type \"Integer\" from string: " + str); } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string std::string v; if (!(is >> v)) { - throw dmlc::Error(": Cannot parse into type \"String\" from string: " + str); + throw Error(": Cannot parse into type \"String\" from string: " + str); } return String(v); } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { @@ -197,14 +197,14 @@ ObjectRef TargetInternal::ParseType(const std::string& str, try { ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); result.push_back(parsed); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { std::string index = "[" + std::to_string(result.size()) + "]"; - throw dmlc::Error(index + e.what()); + throw Error(index + e.what()); } } return Array(result); } - throw dmlc::Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str); + throw Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str); } ObjectRef TargetInternal::ParseType(const ObjectRef& obj, @@ -224,15 +224,14 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, } else if (const auto* ptr = obj.as()) { for (const auto& kv : *ptr) { if (!kv.first->IsInstance()) { - throw dmlc::Error(": Target object requires key of dict to be str, but get: " + - kv.first->GetTypeKey()); + throw Error(": Target object requires key of dict to be str, but get: " + + kv.first->GetTypeKey()); } } Map config = GetRef>(ptr); return Target(TargetInternal::FromConfig({config.begin(), config.end()})); } - throw dmlc::Error(": Expect type 'dict' or 'str' to construct Target, but get: " + - obj->GetTypeKey()); + throw Error(": Expect type 'dict' or 'str' to construct Target, but get: " + obj->GetTypeKey()); } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { // Parsing array const auto* array = ObjTypeCheck(obj, "Array"); @@ -240,9 +239,9 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, for (const ObjectRef& e : *array) { try { result.push_back(TargetInternal::ParseType(e, *info.key)); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { std::string index = '[' + std::to_string(result.size()) + ']'; - throw dmlc::Error(index + e.what()); + throw Error(index + e.what()); } } return Array(result); @@ -254,17 +253,17 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, ObjectRef key, val; try { key = TargetInternal::ParseType(kv.first, *info.key); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { std::ostringstream os; os << "'s key \"" << key << "\"" << e.what(); - throw dmlc::Error(os.str()); + throw Error(os.str()); } try { val = TargetInternal::ParseType(kv.second, *info.val); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { std::ostringstream os; os << "[\"" << key << "\"]" << e.what(); - throw dmlc::Error(os.str()); + throw Error(os.str()); } result[key] = val; } @@ -275,7 +274,7 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, os << ": Parsing type \"" << info.type_key << "\" is not supported for the given object of type \"" << obj->GetTypeKey() << "\". The object is: " << obj; - throw dmlc::Error(os.str()); + throw Error(os.str()); } return obj; } @@ -355,7 +354,7 @@ Target::Target(const String& tag_or_config_or_target_str) { ObjectPtr target; try { target = TargetInternal::FromString(tag_or_config_or_target_str); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { LOG(FATAL) << "ValueError" << e.what() << ". Target creation from string failed: " << tag_or_config_or_target_str; } @@ -366,7 +365,7 @@ Target::Target(const Map& config) { ObjectPtr target; try { target = TargetInternal::FromConfig({config.begin(), config.end()}); - } catch (const dmlc::Error& e) { + } catch (const Error& e) { LOG(FATAL) << "ValueError" << e.what() << ". Target creation from config dict failed: " << config; } @@ -496,7 +495,7 @@ ObjectPtr TargetInternal::FromConfigString(const String& config_str) { "if the python module is properly loaded"; Optional> config = (*loader)(config_str); if (!config.defined()) { - throw dmlc::Error(": Cannot load config dict with python JSON loader"); + throw Error(": Cannot load config dict with python JSON loader"); } return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); } @@ -514,7 +513,7 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { } } if (name.empty()) { - throw dmlc::Error(": Cannot parse empty target string"); + throw Error(": Cannot parse empty target string"); } // Create the target config std::unordered_map config = {{"kind", String(name)}}; @@ -525,17 +524,17 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { // Parse key-value pair std::string s_next = (iter + 1 < options.size()) ? options[iter + 1] : ""; iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value); - } catch (const dmlc::Error& e) { - throw dmlc::Error(": Error when parsing target" + std::string(e.what())); + } catch (const Error& e) { + throw Error(": Error when parsing target" + std::string(e.what())); } try { // check if `key` has been used if (config.count(key)) { - throw dmlc::Error(": The key \"" + key + "\" appears more than once"); + throw Error(": The key \"" + key + "\" appears more than once"); } config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key)); - } catch (const dmlc::Error& e) { - throw dmlc::Error(": Error when parsing target[\"" + key + "\"]" + e.what()); + } catch (const Error& e) { + throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); } } return TargetInternal::FromConfig(config); @@ -554,11 +553,11 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_mapkind = GetTargetKind(GetRef(kind)); config.erase(kKind); } else { - throw dmlc::Error(": Expect type of field \"kind\" is String, but get type: " + - config[kKind]->GetTypeKey()); + throw Error(": Expect type of field \"kind\" is String, but get type: " + + config[kKind]->GetTypeKey()); } } else { - throw dmlc::Error(": Field \"kind\" is not found"); + throw Error(": Field \"kind\" is not found"); } // parse "tag" if (config.count(kTag)) { @@ -566,8 +565,8 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_maptag = GetRef(tag); config.erase(kTag); } else { - throw dmlc::Error(": Expect type of field \"tag\" is String, but get type: " + - config[kTag]->GetTypeKey()); + throw Error(": Expect type of field \"tag\" is String, but get type: " + + config[kTag]->GetTypeKey()); } } else { target->tag = ""; @@ -582,15 +581,15 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_map()) { keys.push_back(GetRef(key)); } else { - throw dmlc::Error( + throw Error( ": Expect 'keys' to be an array of strings, but it " "contains an element of type: " + e->GetTypeKey()); } } } else { - throw dmlc::Error(": Expect type of field \"keys\" is Array, but get type: " + - config[kKeys]->GetTypeKey()); + throw Error(": Expect type of field \"keys\" is Array, but get type: " + + config[kKeys]->GetTypeKey()); } } // add device name @@ -615,8 +614,8 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_mapkind, key); attrs[key] = TargetInternal::ParseType(value, info); - } catch (const dmlc::Error& e) { - throw dmlc::Error(": Error when parsing target[\"" + key + "\"]" + e.what()); + } catch (const Error& e) { + throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); } } // parse host diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 1f7d18f747ea..9e8595d6809c 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -125,7 +125,7 @@ TEST(IRF, ExprTransform) { try { f(z - 1, 2); LOG(FATAL) << "should fail"; - } catch (dmlc::Error&) { + } catch (Error&) { } } diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc index bf5fe94b83ff..a4549344bd11 100644 --- a/tests/cpp/parallel_for_test.cc +++ b/tests/cpp/parallel_for_test.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index ab51b6c79c83..f5c0de0a50b0 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -131,6 +131,8 @@ # microTVM Virtual Machines "apps/microtvm/reference-vm/zephyr/Vagrantfile", "apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template", + # patch file for libbacktrace + "cmake/modules/libbacktrace_macos.patch", } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 62e52abefeb4..8b6b39e3df15 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -827,8 +827,8 @@ def test_import_grad(): mod.import_from_std("gradient.rly") -def test_resnet(): - mod, _ = relay.testing.resnet.get_workload() +def test_mlp(): + mod, _ = relay.testing.mlp.get_workload(1) text = mod.astext() parsed_mod = tvm.parser.parse(text) tvm.ir.assert_structural_equal(mod, parsed_mod) @@ -850,8 +850,8 @@ def inline_params(mod, params): return mod -def test_resnet_inlined_params(): - mod, params = relay.testing.resnet.get_workload() +def test_mlp_inlined_params(): + mod, params = relay.testing.mlp.get_workload(1) mod = inline_params(mod, params) mod = relay.transform.InferType()(mod) text = mod.astext() diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 72a243dbbb67..b2ae28649e6a 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -181,11 +181,6 @@ def test_squeezenet(): astext(net) -def test_vgg(): - net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1) - astext(net) - - def test_densenet(): net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1) astext(net) diff --git a/tests/scripts/task_build.sh b/tests/scripts/task_build.sh index d8e35ebd4de3..845b7153ae20 100755 --- a/tests/scripts/task_build.sh +++ b/tests/scripts/task_build.sh @@ -16,4 +16,4 @@ # specific language governing permissions and limitations # under the License. export VTA_HW_PATH=`pwd`/3rdparty/vta-hw -cd $1 && cmake .. && make $2 && cd .. +cd $1 && cmake .. -DCMAKE_BUILD_TYPE=RelWithDebInfo && make $2 && cd .. diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py index 5ed3ceef5ba0..bc88457f94f9 100644 --- a/tutorials/auto_scheduler/tune_network_cuda.py +++ b/tutorials/auto_scheduler/tune_network_cuda.py @@ -252,7 +252,7 @@ def run_tuning(): # The last line also prints the total number of measurement trials, # total time spent on auto-tuning and the id of the next task to tune. # -# There will also be some "dmlc::Error"s and CUDA errors, because the +# There will also be some "tvm::Error"s and CUDA errors, because the # auto-scheduler will try some invalid schedules. # You can safely ignore them if the tuning can continue, because these # errors are isolated from the main process. diff --git a/tutorials/auto_scheduler/tune_network_mali.py b/tutorials/auto_scheduler/tune_network_mali.py index ca1067b27c80..2bce968771e3 100644 --- a/tutorials/auto_scheduler/tune_network_mali.py +++ b/tutorials/auto_scheduler/tune_network_mali.py @@ -329,7 +329,7 @@ def tune_and_evaluate(): # The last line also prints the total number of measurement trials, # total time spent on auto-tuning and the id of the next task to tune. # -# There will also be some "dmlc::Error"s errors, because the +# There will also be some "tvm::Error"s errors, because the # auto-scheduler will try some invalid schedules. # You can safely ignore them if the tuning can continue, because these # errors are isolated from the main process. diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 8526abbbe6ca..2b47c64729e0 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -251,7 +251,7 @@ def run_tuning(): # The last line also prints the total number of measurement trials, # total time spent on auto-tuning and the id of the next task to tune. # -# There will also be some "dmlc::Error"s errors, because the +# There will also be some "tvm::Error"s errors, because the # auto-scheduler will try some invalid schedules. # You can safely ignore them if the tuning can continue, because these # errors are isolated from the main process. diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index b72caad1e3df..12f930f491a5 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -25,11 +25,9 @@ */ // configurations for the dmlc log. -#define DMLC_LOG_CUSTOMIZE 0 -#define DMLC_LOG_STACK_TRACE 0 -#define DMLC_LOG_DEBUG 0 -#define DMLC_LOG_NODATE 1 -#define DMLC_LOG_FATAL_THROW 0 +#define TVM_LOG_DEBUG 0 +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 #include #include diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 214c1883f874..0b14ef6476d2 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -23,14 +23,12 @@ */ // configurations for the dmlc log. -#define DMLC_LOG_CUSTOMIZE 0 -#define DMLC_LOG_STACK_TRACE 0 -#define DMLC_LOG_DEBUG 0 -#define DMLC_LOG_NODATE 1 -#define DMLC_LOG_FATAL_THROW 0 +#define TVM_LOG_DEBUG 0 +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 -#include #include +#include #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 62b87af01774..01e42ef3faa8 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -22,12 +22,10 @@ * \brief WebGPU runtime based on the TVM JS. */ -// configurations for the dmlc log. -#define DMLC_LOG_CUSTOMIZE 0 -#define DMLC_LOG_STACK_TRACE 0 -#define DMLC_LOG_DEBUG 0 -#define DMLC_LOG_NODATE 1 -#define DMLC_LOG_FATAL_THROW 0 +// configurations for tvm logging. +#define TVM_LOG_DEBUG 0 +#define DMLC_USE_LOGGING_LIBRARY +#define TVM_BACKTRACE_DISABLED 1 #include #include @@ -35,12 +33,27 @@ #include #include +#include +#include + #include "../../src/runtime/meta_data.h" #include "../../src/runtime/vulkan/vulkan_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { namespace runtime { +namespace detail { +// Override logging mechanism +void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { + std::cerr << file << ":" << lineno << ": " << message << std::endl; + abort(); +} + +void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { + std::cerr << file << ":" << lineno << ": " << message << std::endl; +} + +} // namespace detail /*! \brief Thread local workspace */ class WebGPUThreadEntry {