Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
add_definitions(-DHalide_SHARED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
Expand Down Expand Up @@ -112,8 +111,8 @@ else(MSVC)
endif(MSVC)

# add source group
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})
Expand All @@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/node/*.cc
src/schedule/*.cc
)

Expand Down Expand Up @@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})

file(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
Expand Down Expand Up @@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm)
# Related headers
target_include_directories(
tvm
PUBLIC "3rdparty/HalideIR/src"
PUBLIC "topi/include")
target_include_directories(
tvm_topi
Expand Down Expand Up @@ -294,11 +288,6 @@ if (INSTALL_DEV)
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
FILES_MATCHING
Expand All @@ -319,8 +308,6 @@ endif(INSTALL_DEV)

# More target definitions
if(MSVC)
target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
Expand Down
4 changes: 2 additions & 2 deletions apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ inline TNodeRef NullValue() {
}

template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
}

/*! \brief Error thrown during attribute checking. */
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class Layout : public NodeRef {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
Expand All @@ -243,7 +243,7 @@ class Layout : public NodeRef {
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
if (var->var->name_hint == axis.name()) {
return true;
}
}
Expand Down
246 changes: 246 additions & 0 deletions include/tvm/dtype.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file tvm/dtype.h
* \brief Data type used in IR.
*/
#ifndef TVM_DTYPE_H_
#define TVM_DTYPE_H_

#include "runtime/packed_func.h"

namespace tvm {
class Expr;

/*!
* \brief Primitive data types in tvm.
*/
class DataType {
public:
/*! \brief default constructor */
DataType() {}
/*!
* \brief Constructor
* \param dtype The DLDataType
*/
explicit DataType(DLDataType dtype)
: data_(dtype) {}
/*!
* \brief Constructor
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
*/
DataType(int code, int bits, int lanes) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
}
/*! \return The type code. */
int code() const {
return static_cast<int>(data_.code);
}
/*! \return number of bits in the data. */
int bits() const {
return static_cast<int>(data_.bits);
}
/*! \return number of bytes to store each scalar. */
int bytes() const {
return (bits() + 7) / 8;
}
/*! \return number of lanes in the data. */
int lanes() const {
return static_cast<int>(data_.lanes);
}
/*! \return whether type is a scalar type. */
bool is_scalar() const {
return lanes() == 1;
}
/*! \return whether type is a scalar type. */
bool is_bool() const {
return code() == kDLUInt && bits() == 1;
}
/*! \return whether type is a float type. */
bool is_float() const {
return code() == kDLFloat;
}
/*! \return whether type is an int type. */
bool is_int() const {
return code() == kDLInt;
}
/*! \return whether type is an uint type. */
bool is_uint() const {
return code() == kDLUInt;
}
/*! \return whether type is a handle type. */
bool is_handle() const {
return code() == kHandle;
}
/*! \return whether type is a vector type. */
bool is_vector() const {
return lanes() > 1;
}
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
DataType with_lanes(int lanes) const {
return DataType(data_.code, data_.bits, lanes);
}
/*!
* \brief Create a new data type by change bits to a specified value.
* \param bits The target number of bits.
* \return the result type.
*/
DataType with_bits(int bits) const {
return DataType(data_.code, bits, data_.lanes);
}
/*!
* \brief Get the scalar version of the type.
* \return the result type.
*/
DataType element_of() const {
return with_lanes(1);
}
// operator overloadings
bool operator==(const DataType& other) const {
return
data_.code == other.data_.code &&
data_.bits == other.data_.bits &&
data_.lanes == other.data_.lanes;
}
bool operator!=(const DataType& other) const {
return !operator==(other);
}
operator DLDataType () const {
return data_;
}
/*! \return the maximum possible value in this format. */
TVM_DLL Expr max() const;
/*! \return the minimum possible value in this format. */
TVM_DLL Expr min() const;

private:
DLDataType data_;
};

/*!
* \brief Construct an int type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
* \return The constructed data type.
*/
inline DataType Int(int bits, int lanes = 1) {
return DataType(kDLInt, bits, lanes);
}

/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType UInt(int bits, int lanes = 1) {
return DataType(kDLUInt, bits, lanes);
}

/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Bool(int lanes = 1) {
return UInt(1, lanes);
}

/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Float(int bits, int lanes = 1) {
return DataType(kDLFloat, bits, lanes);
}

/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Handle(int bits = 64, int lanes = 1) {
return DataType(kHandle, bits, lanes);
}

/*!
* \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index.
*/
inline DataType TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
}
}

/*!
* \brief Convert DLDataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DataType TVMType2Type(DLDataType t) {
return DataType(t.code, t.bits, t.lanes);
}

/*!
* \brief Convert DataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DLDataType Type2TVMType(DataType t) {
return t.operator DLDataType();
}

/*!
* \brief Get the number of bytes needed in a vector.
* \param dtype The data type.
* \return Number of bytes needed.
*/
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}

// Overload print function.
inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
using namespace tvm::runtime;
return os << dtype.operator DLDataType();
}

// Backward compatibility
using Type = DataType;
} // namespace tvm
#endif // TVM_DTYPE_H_
Loading