From af50bc34e8bdab922564ff5f352fa57e1dedc6bc Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 16 Oct 2020 13:56:33 -0700 Subject: [PATCH 01/13] Many fixes to get unit tests passing on Windows. CUDA tests passing on Windows. --- CMakeLists.txt | 6 +++ apps/cpp_rpc/CMakeLists.txt | 15 ++++-- cmake/modules/LibInfo.cmake | 1 + cmake/utils/FindLLVM.cmake | 2 +- conda/build-environment.yaml | 1 + python/tvm/contrib/cc.py | 9 ++-- python/tvm/contrib/nvcc.py | 6 +++ .../search_policy/sketch_policy.cc | 2 +- src/support/libinfo.cc | 7 ++- src/target/source/codegen_c.cc | 34 ++++++++---- src/target/source/codegen_c_host.cc | 7 +++ src/target/source/codegen_cuda.cc | 52 +++++++++---------- tests/python/conftest.py | 25 +++++++++ .../{test_common.py => test_tvmc_common.py} | 0 ...auto_scheduler_layout_rewrite_networks.py} | 0 .../test_auto_scheduler_cost_model.py | 13 ++--- tests/python/unittest/test_crt.py | 3 +- .../python/unittest/test_custom_datatypes.py | 17 +++--- tests/python/unittest/test_micro_artifact.py | 4 +- 19 files changed, 141 insertions(+), 63 deletions(-) create mode 100644 tests/python/conftest.py rename tests/python/driver/tvmc/{test_common.py => test_tvmc_common.py} (100%) rename tests/python/relay/{test_auto_scheduler_layout_rewrite.py => test_auto_scheduler_layout_rewrite_networks.py} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 769a35318d9d..56170c693e3c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,6 +130,12 @@ if(MSVC) add_compile_options(/wd4180) # DLL interface warning in c++ add_compile_options(/wd4251) + # destructor was implicitly defined as deleted + add_compile_options(/wd4624) + # unary minus operator applied to unsigned type, result still unsigned + add_compile_options(/wd4146) + # 'inline': used more than once + add_compile_options(/wd4141) else(MSVC) set(WARNING_FLAG -Wall) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt index ad8ae1488498..ccac53fc3ca0 100644 --- a/apps/cpp_rpc/CMakeLists.txt +++ b/apps/cpp_rpc/CMakeLists.txt @@ -1,4 +1,6 @@ -set(TVM_RPC_SOURCES +cmake_policy(SET CMP0069 NEW) # suppress cmake warning about IPO + +set(TVM_RPC_SOURCES main.cc rpc_env.cc rpc_server.cc @@ -11,7 +13,12 @@ endif() # Set output to same directory as the other TVM libs set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) add_executable(tvm_rpc ${TVM_RPC_SOURCES}) -set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) + +include(CheckIPOSupported) +check_ipo_supported(RESULT result OUTPUT output) +if(result) + set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) +endif() if(WIN32) target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) @@ -35,5 +42,5 @@ target_include_directories( PUBLIC DLPACK_PATH PUBLIC DMLC_PATH ) - -target_link_libraries(tvm_rpc tvm_runtime) \ No newline at end of file + +target_link_libraries(tvm_rpc tvm_runtime) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index deaa6d9d8362..131dceeb345d 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -75,6 +75,7 @@ function(add_lib_info src_file) TVM_INFO_USE_ARM_COMPUTE_LIB="${USE_ARM_COMPUTE_LIB}" TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME="${USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}" TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}" + TVM_CXX_COMPILER_PATH="${CMAKE_CXX_COMPILER}" ) endfunction() diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index b8c5bf815bf5..9fc4df24b813 100644 --- 
a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -120,7 +120,7 @@ macro(find_llvm use_llvm) string(STRIP ${TVM_LLVM_VERSION} TVM_LLVM_VERSION) # definitions string(REGEX MATCHALL "(^| )-D[A-Za-z0-9_]*" __llvm_defs ${__llvm_cxxflags}) - set(LLVM_DEFINTIIONS "") + set(LLVM_DEFINITIONS "") foreach(__flag IN ITEMS ${__llvm_defs}) string(STRIP "${__flag}" __llvm_def) list(APPEND LLVM_DEFINITIONS "${__llvm_def}") diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 31b39bfafcd0..7c7831e25b1b 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -35,3 +35,4 @@ dependencies: - bzip2 - make - scipy + - pillow diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 9643d9b650fd..21b8d013a28d 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -47,7 +47,7 @@ def create_shared(output, objects, options=None, cc="g++"): ): _linux_compile(output, objects, options, cc, compile_shared=True) elif sys.platform == "win32": - _windows_shared(output, objects, options) + _windows_compile(output, objects, options) else: raise ValueError("Unsupported platform") @@ -71,6 +71,8 @@ def create_executable(output, objects, options=None, cc="g++"): """ if sys.platform == "darwin" or sys.platform.startswith("linux"): _linux_compile(output, objects, options, cc) + elif sys.platform == "win32": + _windows_compile(output, objects, options) else: raise ValueError("Unsupported platform") @@ -212,9 +214,9 @@ def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=F raise RuntimeError(msg) -def _windows_shared(output, objects, options): +def _windows_compile(output, objects, options): cmd = ["clang"] - cmd += ["-O2", "-flto=full", "-fuse-ld=lld-link"] + cmd += ["-O2", "-v", "-fvisibility=default", "-export-all-symbols"] if output.endswith(".so") or output.endswith(".dll"): cmd += ["-shared"] @@ -240,6 +242,7 @@ def _windows_shared(output, objects, options): ) if proc.returncode != 0: msg = "Compilation error:\n" + msg += " ".join(cmd) + "\n" msg += py_str(out) raise RuntimeError(msg) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5886760934fb..2a97b0b31d1e 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -89,6 +89,12 @@ def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None): cmd += ["-o", file_target] cmd += [temp_code] + cxx_compiler_path = tvm.support.libinfo().get("TVM_CXX_COMPILER_PATH") + if cxx_compiler_path != "": + # This tells nvcc where to find the c++ compiler just in case it is not in the path. + # On Windows it is not in the path by default. 
+ cmd += ["-ccbin", cxx_compiler_path] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 91721afdba74..4a4ab18b5eed 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -519,7 +519,7 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul // auxiliary global variables std::vector pop_scores; std::vector pop_selection_probs; - float max_score = -1e-10; + float max_score = -1e-10f; pop_scores.reserve(population); pop_selection_probs.reserve(population); std::uniform_real_distribution<> dis(0.0, 1.0); diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index c8aa76b9d1f5..0f394f50fe71 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -208,6 +208,10 @@ #define TVM_INFO_INDEX_DEFAULT_I64 "NOT-FOUND" #endif +#ifndef TVM_CXX_COMPILER_PATH +#define TVM_CXX_COMPILER_PATH "" +#endif + namespace tvm { /*! @@ -262,7 +266,8 @@ TVM_DLL Map GetLibInfo() { {"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX}, {"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB}, {"USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}, - {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}}; + {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}, + {"TVM_CXX_COMPILER_PATH", TVM_CXX_COMPILER_PATH}}; return result; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..96aedecf6717 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -82,20 +82,18 @@ void CodeGenC::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - + std::stringstream arg_stream; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; std::string vid = AllocVarID(v.get()); - if (i != 0) stream << ", "; + if (i != 0) arg_stream << ", "; if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); + PrintStorageScope(it->second, arg_stream); } - PrintType(GetType(v), stream); + PrintType(GetType(v), arg_stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). 
@@ -106,14 +104,32 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } if (no_alias && restrict_keyword_.length() != 0) { - stream << ' ' << restrict_keyword_; + arg_stream << ' ' << restrict_keyword_; } } else { - PrintType(GetType(v), stream); + PrintType(GetType(v), arg_stream); } - stream << ' ' << vid; + arg_stream << ' ' << vid; } + + stream << "#ifdef _WIN32\n"; + stream << "#undef TVM_DLL\n"; + stream << "#define TVM_DLL __declspec(dllexport)\n"; + stream << "#endif\n"; + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + stream << arg_stream.str(); + stream << ");\n"; + stream << "#ifdef _WIN32\n"; + stream << "#undef TVM_DLL\n"; + stream << "#define TVM_DLL\n"; + stream << "#endif\n"; + + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + stream << arg_stream.str(); stream << ") {\n"; + this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index bee5441649c5..fadd5f1660fc 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -61,6 +61,13 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { } void CodeGenCHost::LinkParameters(Map params) { + stream << "#define TVM_EXPORTS\n"; + PrintFuncPrefix(); + stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param + << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " + << "int* out_ret_tcode, void* resource_handle);\n"; + + stream << "#undef TVM_EXPORTS\n"; PrintFuncPrefix(); stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index e5547315613f..18d8cd17e868 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -79,6 +79,16 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + decl_stream << "\n#ifdef _WIN32\n"; + decl_stream << " using uint = unsigned int;\n"; + decl_stream << " using uchar = unsigned char;\n"; + decl_stream << " using int64_t = long long;\n"; + decl_stream << " using uint64_t = unsigned long long;\n"; + decl_stream << "#else\n"; + decl_stream << " using int64_t = long;\n"; + decl_stream << " using uint64_t = ulong;\n"; + decl_stream << "#endif\n"; + return CodeGenC::Finish(); } @@ -99,7 +109,7 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - ICHECK_EQ(lanes, 1) << "do not yet support vector types"; + ICHECK(t.is_scalar()) << "do not yet support vector types"; os << "void*"; return; } @@ -108,7 +118,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) switch (t.bits()) { case 16: enable_fp16_ = true; - if (lanes == 1) { + if (t.is_scalar()) { os << "half"; } else if (lanes <= 8) { // Emit CUDA code to access fp16 vector elements. 
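For context on the _WIN32 alias block added to CodeGenCUDA::Finish() above: the generated CUDA source spells types as uint, uchar and int64_t/uint64_t. MSVC's headers do not provide the lowercase u* names, and long is only 32 bits under its LLP64 model, so without the aliases nvcc fails on Windows. Below is a hypothetical kernel in the style of TVM's CUDA output, with stand-in aliases so the fragment is self-contained.

    // Hypothetical kernel, not taken from a real generated module.
    typedef unsigned int uint;               // stands in for the prologue's alias in this sketch
    extern "C" __global__ void add_one_kernel(long long* data, uint n) {  // int64_t maps to long long on Windows
      uint i = blockIdx.x * blockDim.x + threadIdx.x;   // short names like uint appear throughout TVM's CUDA output
      if (i < n) {
        data[i] = data[i] + 1;
      }
    }
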
@@ -136,7 +146,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; break; } - if (!fail && (lanes == 1 || t.bits() == 16)) return; + if (!fail && (t.is_scalar() || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; @@ -154,15 +164,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { - if (t.lanes() != 1) { - os << "u"; - } else { - os << "unsigned "; - } + os << "u"; } switch (t.bits()) { case 1: { - if (t.lanes() == 1) { + if (t.is_scalar()) { os << "int"; return; } else if (t.lanes() == 8) { @@ -179,7 +185,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 4: { - if (t.lanes() == 1) { + if (t.is_scalar()) { os << "int"; return; } else if (t.lanes() == 4) { @@ -220,7 +226,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) enable_int8_ = true; os << "int4"; return; - } else if (!t.is_uint() && t.lanes() == 1) { + } else if (!t.is_uint() && t.is_scalar()) { os << "signed char"; break; } else { @@ -235,22 +241,16 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "int"; break; case 64: { - if (sizeof(long) != 8) { // NOLINT(*) - if (t.lanes() == 1) { - os << "long long"; - break; - } else if (t.lanes() == 2) { - os << "longlong"; - break; - } else { - // No longlong3, longlong4 - LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform"; - break; - } - } else { - os << "long"; - break; + if (t.is_scalar()) { + os << "int64_t"; + } else if (t.lanes() == 2) { + os << "longlong2"; + } else if (t.lanes() == 3) { + os << "longlong3"; + } else if (t.lanes() == 4) { + os << "longlong4"; } + return; } default: fail = true; diff --git a/tests/python/conftest.py b/tests/python/conftest.py new file mode 100644 index 000000000000..5276fa3f3db9 --- /dev/null +++ b/tests/python/conftest.py @@ -0,0 +1,25 @@ +import sys +import tvm + +collect_ignore = [] +if sys.platform.startswith("win"): + collect_ignore.append("frontend/caffe") + collect_ignore.append("frontend/caffe2") + collect_ignore.append("frontend/coreml") + collect_ignore.append("frontend/darknet") + collect_ignore.append("frontend/keras") + collect_ignore.append("frontend/mxnet") + collect_ignore.append("frontend/pytorch") + collect_ignore.append("frontend/tensorflow") + collect_ignore.append("frontend/tflite") + collect_ignore.append("frontend/onnx") + collect_ignore.append("driver/tvmc/test_autoscheduler.py") + collect_ignore.append("unittest/test_auto_scheduler_cost_model.py") # stack overflow + # collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored + collect_ignore.append("unittest/test_auto_scheduler_search_policy.py") # stack overflow + # collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored + + collect_ignore.append("unittest/test_tir_intrin.py") + +if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON": + collect_ignore.append("unittest/test_micro_transport.py") diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_tvmc_common.py similarity index 100% rename from tests/python/driver/tvmc/test_common.py rename to tests/python/driver/tvmc/test_tvmc_common.py diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py similarity index 100% rename from 
tests/python/relay/test_auto_scheduler_layout_rewrite.py rename to tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py index 36360da45c8d..c348a690afdc 100644 --- a/tests/python/unittest/test_auto_scheduler_cost_model.py +++ b/tests/python/unittest/test_auto_scheduler_cost_model.py @@ -68,14 +68,15 @@ def test_xgb_model(): assert rmse <= 0.3 # test loading a record file - with tempfile.NamedTemporaryFile() as fp: - auto_scheduler.save_records(fp.name, inputs, results) - model.update_from_file(fp.name) + tmpdir = tvm.contrib.util.tempdir() + tmpfile = tmpdir.relpath("test1") + auto_scheduler.save_records(tmpfile, inputs, results) + model.update_from_file(tmpfile) # test model serialization - with tempfile.NamedTemporaryFile() as fp: - model.save(fp.name) - model.load(fp.name) + tmpfile = tmpdir.relpath("test2") + model.save(tmpfile) + model.load(tmpfile) if __name__ == "__main__": diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 4b744b8ee10a..3fd0d1ad42b3 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -19,7 +19,8 @@ import copy import glob import os -import pty +import pytest +pytest.importorskip('pty') import sys import subprocess import textwrap diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 6aad93abd510..ee6f77e1ceb9 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -21,7 +21,6 @@ import tvm.topi.testing import numpy as np import pytest -from numpy.random import MT19937, RandomState, SeedSequence from tvm import relay from tvm.relay.testing.layers import batch_norm_infer from tvm.target.datatype import ( @@ -63,11 +62,9 @@ def get_cat_image(dimensions): img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :] return np.asarray(img, dtype="float32") - # we use a random seed to generate input_data # to guarantee stable tests -rs = RandomState(MT19937(SeedSequence(123456789))) - +np.random.seed(0) def convert_ndarray(dst_dtype, array): """Converts NDArray(s) into the specified datatype""" @@ -341,7 +338,7 @@ def check_unary_op(op, src_dtype, dst_dtype, shape): t1 = relay.TensorType(shape, src_dtype) x = relay.var("x", t1) z = op(x) - x_data = rs.rand(*shape).astype(t1.dtype) + x_data = np.random.rand(*shape).astype(t1.dtype) module = tvm.IRModule.from_expr(relay.Function([x], z)) @@ -372,8 +369,8 @@ def check_binary_op(opfunc, src_dtype, dst_dtype): x = relay.var("x", t1) y = relay.var("y", t2) z = opfunc(x, y) - x_data = rs.rand(*shape1).astype(t1.dtype) - y_data = rs.rand(*shape2).astype(t2.dtype) + x_data = np.random.rand(*shape1).astype(t1.dtype) + y_data = np.random.rand(*shape2).astype(t2.dtype) module = tvm.IRModule.from_expr(relay.Function([x, y], z)) compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol) @@ -416,8 +413,8 @@ def run_test_conv2d( w = relay.var("w", shape=kshape, dtype=src_dtype) y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs) module = tvm.IRModule.from_expr(relay.Function([x, w], y)) - data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype) - kernel = rs.uniform(-scale, scale, size=kshape).astype(src_dtype) + data = np.random.uniform(-scale, scale, size=dshape).astype(src_dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(src_dtype) 
compare(module, (data, kernel), src_dtype, dst_dtype, rtol, atol) @@ -497,7 +494,7 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6): bn = batch_norm_infer(data=x, epsilon=2e-5, scale=False, name="bn_x") f = relay.Function(relay.analysis.free_vars(bn), bn) - x_data = rs.rand(*shape).astype(t.dtype) + x_data = np.random.rand(*shape).astype(t.dtype) module = tvm.IRModule.from_expr(f) zero_data = np.zeros((32), "float32") diff --git a/tests/python/unittest/test_micro_artifact.py b/tests/python/unittest/test_micro_artifact.py index d757f0956b81..85598f36bce8 100644 --- a/tests/python/unittest/test_micro_artifact.py +++ b/tests/python/unittest/test_micro_artifact.py @@ -17,13 +17,15 @@ """Unit tests for the artifact module.""" +import pytest import json import os import shutil import tvm from tvm.contrib import utils - +pytest.importorskip('tvm.micro') +from tvm.micro import artifact FILE_LIST = ["label1", "label2", "label12", "unlabelled"] From f23b11cbab8162a0647f68b72ff49930483f2de7 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Tue, 9 Feb 2021 16:27:36 -0800 Subject: [PATCH 02/13] Add support for ushort on Windows --- src/target/source/codegen_cuda.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 18d8cd17e868..e1098e6d2b1f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -82,6 +82,7 @@ std::string CodeGenCUDA::Finish() { decl_stream << "\n#ifdef _WIN32\n"; decl_stream << " using uint = unsigned int;\n"; decl_stream << " using uchar = unsigned char;\n"; + decl_stream << " using ushort = unsigned short;\n"; decl_stream << " using int64_t = long long;\n"; decl_stream << " using uint64_t = unsigned long long;\n"; decl_stream << "#else\n"; From 5f93ee0d36b5b00f0ac97c81e1443770e140942f Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 12 Feb 2021 12:18:42 -0800 Subject: [PATCH 03/13] Cleanup to minimial changes to get Windows compile working --- python/tvm/contrib/cc.py | 2 +- src/target/source/codegen_c.cc | 34 ++++++++--------------------- src/target/source/codegen_c_host.cc | 1 + 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 21b8d013a28d..59a1d11216ee 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -216,7 +216,7 @@ def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=F def _windows_compile(output, objects, options): cmd = ["clang"] - cmd += ["-O2", "-v", "-fvisibility=default", "-export-all-symbols"] + cmd += ["-O2"] if output.endswith(".so") or output.endswith(".dll"): cmd += ["-shared"] diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 96aedecf6717..af175c7f2208 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -82,18 +82,20 @@ void CodeGenC::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - std::stringstream arg_stream; + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; std::string vid = AllocVarID(v.get()); - if (i != 0) arg_stream << ", "; + if (i != 0) stream << ", "; if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, arg_stream); + 
PrintStorageScope(it->second, stream); } - PrintType(GetType(v), arg_stream); + PrintType(GetType(v), stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -104,32 +106,14 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } if (no_alias && restrict_keyword_.length() != 0) { - arg_stream << ' ' << restrict_keyword_; + stream << ' ' << restrict_keyword_; } } else { - PrintType(GetType(v), arg_stream); + PrintType(GetType(v), stream); } - arg_stream << ' ' << vid; + stream << ' ' << vid; } - - stream << "#ifdef _WIN32\n"; - stream << "#undef TVM_DLL\n"; - stream << "#define TVM_DLL __declspec(dllexport)\n"; - stream << "#endif\n"; - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - stream << arg_stream.str(); - stream << ");\n"; - stream << "#ifdef _WIN32\n"; - stream << "#undef TVM_DLL\n"; - stream << "#define TVM_DLL\n"; - stream << "#endif\n"; - - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - stream << arg_stream.str(); stream << ") {\n"; - this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index fadd5f1660fc..e0a6bcdf5a0c 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -44,6 +44,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s emit_asserts_ = emit_asserts; declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; + decl_stream << "#define TVM_EXPORTS\n"; decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \n"; From 2d3cf69d712ca9ac9c55efadf9a13c2d5d5d237e Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 16 Oct 2020 13:56:33 -0700 Subject: [PATCH 04/13] Many fixes to get unit tests passing on Windows. CUDA tests passing on Windows. --- src/target/source/codegen_c.cc | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..96aedecf6717 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -82,20 +82,18 @@ void CodeGenC::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - + std::stringstream arg_stream; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; std::string vid = AllocVarID(v.get()); - if (i != 0) stream << ", "; + if (i != 0) arg_stream << ", "; if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); + PrintStorageScope(it->second, arg_stream); } - PrintType(GetType(v), stream); + PrintType(GetType(v), arg_stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). 
@@ -106,14 +104,32 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } if (no_alias && restrict_keyword_.length() != 0) { - stream << ' ' << restrict_keyword_; + arg_stream << ' ' << restrict_keyword_; } } else { - PrintType(GetType(v), stream); + PrintType(GetType(v), arg_stream); } - stream << ' ' << vid; + arg_stream << ' ' << vid; } + + stream << "#ifdef _WIN32\n"; + stream << "#undef TVM_DLL\n"; + stream << "#define TVM_DLL __declspec(dllexport)\n"; + stream << "#endif\n"; + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + stream << arg_stream.str(); + stream << ");\n"; + stream << "#ifdef _WIN32\n"; + stream << "#undef TVM_DLL\n"; + stream << "#define TVM_DLL\n"; + stream << "#endif\n"; + + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + stream << arg_stream.str(); stream << ") {\n"; + this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); From 288cdcf9c2da17a9fb4a15f6b208ebac08e96e64 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Thu, 11 Feb 2021 13:11:35 -0800 Subject: [PATCH 05/13] format --- tests/python/unittest/test_crt.py | 3 ++- tests/python/unittest/test_custom_datatypes.py | 2 ++ tests/python/unittest/test_micro_artifact.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 3fd0d1ad42b3..1bd24c931b72 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -20,7 +20,8 @@ import glob import os import pytest -pytest.importorskip('pty') + +pytest.importorskip("pty") import sys import subprocess import textwrap diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index ee6f77e1ceb9..75e807456981 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -62,10 +62,12 @@ def get_cat_image(dimensions): img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :] return np.asarray(img, dtype="float32") + # we use a random seed to generate input_data # to guarantee stable tests np.random.seed(0) + def convert_ndarray(dst_dtype, array): """Converts NDArray(s) into the specified datatype""" x = relay.var("x", shape=array.shape, dtype=str(array.dtype)) diff --git a/tests/python/unittest/test_micro_artifact.py b/tests/python/unittest/test_micro_artifact.py index 85598f36bce8..fc180200720d 100644 --- a/tests/python/unittest/test_micro_artifact.py +++ b/tests/python/unittest/test_micro_artifact.py @@ -24,7 +24,8 @@ import tvm from tvm.contrib import utils -pytest.importorskip('tvm.micro') + +pytest.importorskip("tvm.micro") from tvm.micro import artifact FILE_LIST = ["label1", "label2", "label12", "unlabelled"] From e992f2c3815c5f3c1c762196ce03bf31a5c8b3a2 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 12 Feb 2021 12:25:02 -0800 Subject: [PATCH 06/13] Revert changes --- src/target/source/codegen_c.cc | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 96aedecf6717..af175c7f2208 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -82,18 +82,20 @@ void CodeGenC::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - std::stringstream arg_stream; + 
this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol.value()) << "("; + for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; std::string vid = AllocVarID(v.get()); - if (i != 0) arg_stream << ", "; + if (i != 0) stream << ", "; if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, arg_stream); + PrintStorageScope(it->second, stream); } - PrintType(GetType(v), arg_stream); + PrintType(GetType(v), stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -104,32 +106,14 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } if (no_alias && restrict_keyword_.length() != 0) { - arg_stream << ' ' << restrict_keyword_; + stream << ' ' << restrict_keyword_; } } else { - PrintType(GetType(v), arg_stream); + PrintType(GetType(v), stream); } - arg_stream << ' ' << vid; + stream << ' ' << vid; } - - stream << "#ifdef _WIN32\n"; - stream << "#undef TVM_DLL\n"; - stream << "#define TVM_DLL __declspec(dllexport)\n"; - stream << "#endif\n"; - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - stream << arg_stream.str(); - stream << ");\n"; - stream << "#ifdef _WIN32\n"; - stream << "#undef TVM_DLL\n"; - stream << "#define TVM_DLL\n"; - stream << "#endif\n"; - - this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol.value()) << "("; - stream << arg_stream.str(); stream << ") {\n"; - this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); From 4dca7c26599db3618489c24f19563a0fe333d9b9 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 12 Feb 2021 12:30:50 -0800 Subject: [PATCH 07/13] revert --- src/target/source/codegen_c_host.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index e0a6bcdf5a0c..3ec64ed2ace9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -62,13 +62,6 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { } void CodeGenCHost::LinkParameters(Map params) { - stream << "#define TVM_EXPORTS\n"; - PrintFuncPrefix(); - stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param - << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " - << "int* out_ret_tcode, void* resource_handle);\n"; - - stream << "#undef TVM_EXPORTS\n"; PrintFuncPrefix(); stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " From 6f7a7eead7ab07abcc6392bff5a35a4517c3cf3d Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Sat, 13 Feb 2021 06:17:03 -0800 Subject: [PATCH 08/13] Add license header --- tests/python/conftest.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 5276fa3f3db9..e8042c8f5095 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import sys import tvm From 29724e8eb46a6083c015aa10940895aa06286d73 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 19 Feb 2021 09:12:04 -0800 Subject: [PATCH 09/13] Fix namespace --- tests/python/unittest/test_auto_scheduler_cost_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py index c348a690afdc..0b34615583db 100644 --- a/tests/python/unittest/test_auto_scheduler_cost_model.py +++ b/tests/python/unittest/test_auto_scheduler_cost_model.py @@ -68,7 +68,7 @@ def test_xgb_model(): assert rmse <= 0.3 # test loading a record file - tmpdir = tvm.contrib.util.tempdir() + tmpdir = tvm.contrib.utils.tempdir() tmpfile = tmpdir.relpath("test1") auto_scheduler.save_records(tmpfile, inputs, results) model.update_from_file(tmpfile) From 19e3c9af90c371ac486ab05d93c689912f94c74b Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Fri, 19 Feb 2021 11:37:27 -0800 Subject: [PATCH 10/13] catch any exceptions in ObjectBase destructor to prevent stack overflow --- python/tvm/auto_scheduler/cost_model/xgb_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py index aab36c175c3c..3cf65954be7f 100644 --- a/python/tvm/auto_scheduler/cost_model/xgb_model.py +++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py @@ -116,11 +116,13 @@ def __init__( if xgb is None: xgb = __import__("xgboost") except ImportError: + # add "from Node" to silence + # "During handling of the above exception, another exception occurred" raise ImportError( "XGBoost is required for XGBModel. " "Please install its python package first. 
" "Help: (https://xgboost.readthedocs.io/en/latest/) " - ) + ) from None self.xgb_params = { "max_depth": 10, From 57d81b4ffc0f79beb55e0b982cf4c26eec34eb86 Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Mon, 22 Feb 2021 18:19:42 +0000 Subject: [PATCH 11/13] Fix build error using nvrtc --- src/target/source/codegen_cuda.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index e1098e6d2b1f..1b54140b636e 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -86,8 +86,8 @@ std::string CodeGenCUDA::Finish() { decl_stream << " using int64_t = long long;\n"; decl_stream << " using uint64_t = unsigned long long;\n"; decl_stream << "#else\n"; - decl_stream << " using int64_t = long;\n"; - decl_stream << " using uint64_t = ulong;\n"; + decl_stream << " #define int64_t long\n"; + decl_stream << " #define uint64_t ulong\n"; decl_stream << "#endif\n"; return CodeGenC::Finish(); From fac59dc1b24483d51198b69d02fbbccf8171d21f Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Tue, 23 Feb 2021 18:06:37 +0000 Subject: [PATCH 12/13] Add undefined types for cuda --- src/target/source/codegen_cuda.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 1b54140b636e..0080f2755f62 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -86,6 +86,8 @@ std::string CodeGenCUDA::Finish() { decl_stream << " using int64_t = long long;\n"; decl_stream << " using uint64_t = unsigned long long;\n"; decl_stream << "#else\n"; + decl_stream << " #define uint unsigned int\n"; + decl_stream << " #define uchar unsigned char\n"; decl_stream << " #define int64_t long\n"; decl_stream << " #define uint64_t ulong\n"; decl_stream << "#endif\n"; From 266c44c1c1ec37b43b3b726f6fe8bc4af0ac083f Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Tue, 23 Feb 2021 21:36:52 +0000 Subject: [PATCH 13/13] add ushort --- src/target/source/codegen_cuda.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 0080f2755f62..35b94f55e4e4 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -88,6 +88,7 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#else\n"; decl_stream << " #define uint unsigned int\n"; decl_stream << " #define uchar unsigned char\n"; + decl_stream << " #define ushort unsigned short\n"; decl_stream << " #define int64_t long\n"; decl_stream << " #define uint64_t ulong\n"; decl_stream << "#endif\n";