Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <string>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"

namespace tvm {
Expand Down Expand Up @@ -73,7 +74,6 @@ inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
}


/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!
Expand Down
46 changes: 46 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
Expand Down Expand Up @@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr {
};


/*!
* \brief Container of constant ineteger (IntImm).
*
* This is used to store and automate type check
* attributes that must be constant integer.
*/
class Integer : public Expr {
public:
Integer() : Expr() {}
/*!
* \brief constructor from node.
*/
explicit Integer(NodePtr<Node> node) : Expr(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : Expr(value) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const Integer& other) {
node_ = other.node_;
return *this;
}
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(node_.get());
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImm;
};


/*! \brief container class of iteration variable. */
class IterVarNode;

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <sstream>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>

#include "base.h"
Expand Down Expand Up @@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
inline TVMArgValue::operator HalideIR::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
Expand All @@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr);
}

inline TVMArgValue::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Integer(sptr);
}

inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<NodePtr<Node> >();
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
IndexExpr size;
IndexExpr axis;
int axis;
double bias;
double alpha;
double beta;
Expand All @@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
/*! \brief Attributes for L2Normalize operator */
struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
double eps;
Array<IndexExpr> axis;
Array<Integer> axis;

TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
TVM_ATTR_FIELD(eps)
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {

/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<IndexExpr> axes;
Array<Integer> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The target axes order, reverse order if not specified.");
Expand All @@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}; // struct ReshapeAttrs

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
IndexExpr axis;
Integer axis;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
}
};
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ struct Expr;
#endif

namespace tvm {
// forward declarations
class Integer;

namespace runtime {
// forward declarations
class TVMArgs;
Expand Down Expand Up @@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ {
inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
};
Expand Down
6 changes: 3 additions & 3 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs);

Expr MakeLRN(Expr data,
IndexExpr size,
IndexExpr axis,
int axis,
double alpha,
double beta,
double bias) {
Expand All @@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn")
});

RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.
.describe(R"code(LRN layer.

Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
Expand All @@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);

Expr MakeL2Normalize(Expr data,
double eps,
Array<IndexExpr> axis) {
Array<Integer> axis) {
auto attrs = make_node<L2NormalizeAttrs>();
attrs->eps = eps;
attrs->axis = std::move(axis);
Expand Down
23 changes: 11 additions & 12 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types,
}
const auto* param = attrs.as<TransposeAttrs>();
const int ndim = data->shape.size();
const Array<IndexExpr>& axes = param->axes;
const Array<Integer>& axes = param->axes;
// check dimension match
CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim)
CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
<< "Dimension mismatch: axes has " << axes.size() << " elements"
<< ", but data.ndim = " << ndim;
// construct int_axes
std::vector<int> int_axes;
int_axes.reserve(ndim);
if (axes.empty()) {
// used not defined to check if it is None.
if (!axes.defined()) {
for (int i = ndim - 1; i >= 0; --i) {
int_axes.push_back(i);
}
} else {
std::vector<int> axis_used(ndim, 0);
for (const IndexExpr& e : axes) {
const int64_t *axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
int axis = *axis_ptr;
for (const Integer& e : axes) {
int64_t axis = e;
// sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim)
<< "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
Expand All @@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types,
// sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
axis_used[axis] = 1;
int_axes.push_back(axis);
int_axes.push_back(static_cast<int>(axis));
}
}
std::vector<IndexExpr> oshape;
Expand All @@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types,
}

Expr MakeTranspose(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axes) {
auto attrs = make_node<TransposeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose");
Expand Down Expand Up @@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types,
std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis));
int axis = static_cast<int>(param->axis->value);
if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape"
Expand All @@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types,

Expr MakeTake(Expr data,
Expr indices,
IndexExpr axis) {
Integer axis) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = axis;
attrs->axis = std::move(axis);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}
Expand Down