Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <cstring>
#include <string>
#include <type_traits>

Expand Down Expand Up @@ -71,11 +72,16 @@ class DataType {
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
* \param scalable Whether or not the data type is scalable.
*/
DataType(int code, int bits, int lanes) {
DataType(int code, int bits, int lanes, bool scalable = false) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (scalable) {
data_.lanes = static_cast<uint16_t>(-lanes);
} else {
data_.lanes = static_cast<uint16_t>(lanes);
}
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
Expand All @@ -90,7 +96,14 @@ class DataType {
/*! \return number of bytes to store each scalar. */
int bytes() const {
  // Adding (bits-per-byte - 1) before dividing rounds a partial byte up.
  constexpr int kBitsPerByte = 8;
  return (bits() + (kBitsPerByte - 1)) / kBitsPerByte;
}
/*! \return number of lanes in the data. */
int lanes() const { return static_cast<int>(data_.lanes); }
int lanes() const {
  // Reinterpret the stored lanes as a signed 16-bit value to recover the sign;
  // a negative value marks a scalable type (see is_scalable()).
  const int signed_lanes = static_cast<int16_t>(data_.lanes);
  // Scalable types report the magnitude of the stored value.
  return is_scalable() ? -signed_lanes : signed_lanes;
}
/*! \return whether type is a scalar type. */
bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
Expand All @@ -114,17 +127,28 @@ class DataType {
/*! \return whether type is a handle type. */
bool is_handle() const {
  // The void type also uses the kHandle code, so it is excluded explicitly.
  if (code() != DataType::kHandle) return false;
  return !is_void();
}
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
bool is_vector() const {
  // Inspect the raw signed encoding so that scalable (negative) lane values
  // count as vectors; only the encodings 0 and 1 are non-vector.
  const auto raw_lanes = static_cast<int16_t>(data_.lanes);
  return !(raw_lanes == 0 || raw_lanes == 1);
}
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const {
  // A bool vector is a vector whose elements are one bit wide.
  if (!is_vector()) return false;
  return bits() == 1;
}
/*! \return whether type is a Void type. */
bool is_void() const {
  // Void is a handle-coded type with zero bits and zero lanes.
  const bool handle_code = code() == DataType::kHandle;
  return handle_code && bits() == 0 && lanes() == 0;
}
/*! \return Whether the type is scalable. */
bool is_scalable() const {
  // A negative value in the signed 16-bit view of the lanes field marks a
  // scalable type.
  const int16_t signed_lanes = static_cast<int16_t>(data_.lanes);
  return signed_lanes < 0;
}
/*!
* \brief Create a new data type by changing the lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
DataType with_lanes(int lanes) const {
  // Type code and bit width are carried over; only the lane count changes.
  DataType result(data_.code, data_.bits, lanes);
  return result;
}
/*!
* \brief Create a new scalable data type by changing the lanes to a specified value.
* \param lanes The target number of lanes.
* \return A copy of the old DataType with the number of scalable lanes.
*/
DataType with_scalable_lanes(int lanes) const {
  // Scalable types are marked by a negative lane value (see is_scalable()),
  // so the requested count is negated before being handed to the constructor.
  const int encoded_lanes = -lanes;
  return DataType(data_.code, data_.bits, encoded_lanes);
}
/*!
* \brief Create a new data type by changing the bits to a specified value.
* \param bits The target number of bits.
Expand Down Expand Up @@ -247,6 +271,9 @@ class DataType {
* \return Number of bytes needed.
*/
inline int GetVectorBytes(DataType dtype) {
if (dtype.is_scalable()) {
LOG(FATAL) << "Cannot get vector bytes of scalable vector";
}
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
Expand Down Expand Up @@ -357,8 +384,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);

int16_t lanes = static_cast<int16_t>(t.lanes);
if (lanes > 1) {
os << 'x' << lanes;
} else if (lanes < 0) {
os << 'x' << -lanes << "xvscale";
}
return os;
}
Expand Down Expand Up @@ -424,6 +455,10 @@ inline DLDataType String2DLDataType(std::string s) {
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
if (strncmp(endpt, "xvscale", 7) == 0) {
t.lanes = -t.lanes;
endpt = endpt + 7;
}
ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,9 @@ Var EnvThread(String thread_tag);
* \param buffer The buffer.
* \param value The value to be stored.
* \param indices The indices location to be stored.
* \param predicate A vector mask of int1 values that prevents storing values on masked-off lanes.
*/
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, PrimExpr predicate);

/*!
* \brief The prefetch hint for a buffer
Expand Down
14 changes: 10 additions & 4 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,23 @@ class Stage : public ObjectRef {
* \param factor The split factor of the loop.
* \param p_outer The result outer domain
* \param p_inner The result inner domain.
* \param disable_predication If true, do not predicate the loop.
* \return reference to self.
*/
TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
IterVar* p_inner); // NOLINT(*)
TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner,
bool disable_predication = false); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
* \param parent The parent domain.
* \param nparts The number of parts in the outer domain.
* \param p_outer The result outer domain.
* \param p_inner The result inner domain.
* \param disable_predication If true, do not predicate the loop.
* \return reference to self.
*/
TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
IterVar* p_inner); // NOLINT(*)
IterVar* p_inner, bool disable_predication = false); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
Expand Down Expand Up @@ -761,13 +763,16 @@ class SplitNode : public IterVarRelationNode {
PrimExpr factor;
/*! \brief Number of parts, only factor or nparts can be given */
PrimExpr nparts;
/*! \brief Whether to disable the predication */
bool disable_predication;

/*!
 * \brief Report every field of this node to an attribute visitor.
 * \param v The visitor; each field is passed by address together with the
 *          string key it is registered under.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
v->Visit("nparts", &nparts);
v->Visit("disable_predication", &disable_predication);
}

static constexpr const char* _type_key = "Split";
Expand All @@ -780,7 +785,8 @@ class SplitNode : public IterVarRelationNode {
*/
class Split : public IterVarRelation {
public:
// NOTE(review): two constructor declarations appear here, one without and one
// with `disable_predication`. The flag has no default value, so they are
// distinct overloads — confirm whether the five-argument form is meant to stay.
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts,
bool disable_predication);

TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,18 @@ TVM_DLL const Op& anylist_setitem_call_packed();
*/
TVM_DLL const Op& anylist_setitem_call_cpacked();

/*!
* \brief Return the value of vscale
*/
TVM_DLL const Op& vscale();

/*!
* \brief Provide a predicate constructed from the currently active lanes.
*
* Calculate the active lane masks given a bound and a current value
*/
TVM_DLL const Op& get_active_lane_mask();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
17 changes: 11 additions & 6 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,23 +630,27 @@ class BufferLoadNode : public PrimExprNode {
Buffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
/*! \brief The buffer predicate */
PrimExpr predicate;

/*!
 * \brief Report every field of this node to an attribute visitor.
 * \param v The visitor; each field is passed by address together with the
 *          string key it is registered under.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
equal(indices, other->indices);
equal(indices, other->indices) && equal(predicate, other->predicate);
}

/*!
 * \brief Feed dtype, buffer, indices and predicate to the hash reducer, in that order.
 * \param hash_reduce The reducer that accumulates the hash.
 */
// NOTE(review): `span` is visited in VisitAttrs but not hashed here —
// presumably intentional; confirm.
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferLoad";
Expand Down Expand Up @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode {
*/
class BufferLoad : public PrimExpr {
public:
// NOTE(review): two constructor declarations appear here. Because the trailing
// parameters of both are defaulted, a call passing only (buffer, indices)
// matches either one — confirm that only the form taking `predicate` is meant
// to remain.
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
// `predicate` defaults to a PrimExpr constructed from nullptr.
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
PrimExpr predicate = PrimExpr(nullptr), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
Expand Down Expand Up @@ -746,7 +751,7 @@ class RampNode : public PrimExprNode {
/*! \brief The stride of each step. */
PrimExpr stride;
/*! \brief Total number of lanes. */
int lanes;
PrimExpr lanes;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand Down Expand Up @@ -778,7 +783,7 @@ class RampNode : public PrimExprNode {
*/
class Ramp : public PrimExpr {
public:
// NOTE(review): constructors taking `lanes` as a plain int and as a PrimExpr
// are both declared here — confirm whether the int form is meant to remain.
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};
Expand All @@ -789,7 +794,7 @@ class BroadcastNode : public PrimExprNode {
/*! \brief The base value. */
PrimExpr value;
/*! \brief The number of lanes. */
int lanes;
PrimExpr lanes;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand Down Expand Up @@ -818,7 +823,7 @@ class BroadcastNode : public PrimExprNode {
*/
class Broadcast : public PrimExpr {
public:
// NOTE(review): constructors taking `lanes` as a plain int and as a PrimExpr
// are both declared here — confirm whether the int form is meant to remain.
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,15 @@ class ScheduleNode : public runtime::Object {
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \param disable_predication If enabled, don't create a predicate for guarding the
* loop. This can be useful when splitting with scalable factors that the schedule writer
* knows are divisible. Warning: enabling this feature may result in incorrect code generation
* if not used carefully.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters = true) = 0;
bool preserve_unit_iters = true,
bool disable_predication = false) = 0;
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
8 changes: 6 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,23 +231,27 @@ class BufferStoreNode : public StmtNode {
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
/*! \brief The predicate for this store. */
PrimExpr predicate;

/*!
 * \brief Report every field of this node to an attribute visitor.
 * \param v The visitor; each field is passed by address together with the
 *          string key it is registered under.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(value, other->value) &&
equal(indices, other->indices);
equal(indices, other->indices) && equal(predicate, other->predicate);
}

/*!
 * \brief Feed buffer, value, indices and predicate to the hash reducer, in that order.
 * \param hash_reduce The reducer that accumulates the hash.
 */
// NOTE(review): `span` is visited in VisitAttrs but not hashed here —
// presumably intentional; confirm.
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferStore";
Expand All @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());
PrimExpr predicate = PrimExpr(nullptr), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def __init__(self, type_str):

arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
# Lane counts are wrapped to an unsigned 16-bit value. A scalable vector
# ("<type>x<N>xvscale") is marked by storing -N, which wraps to its
# two's-complement form; a fixed-width vector ("<type>x<N>") stores N as is.
# `.value` yields a plain int, so later code such as
# `ctypes.c_int16(self.lanes)` receives an integer, not a ctypes object.
if len(arr) == 3 and arr[2] == "vscale":
    self.lanes = ctypes.c_uint16(-int(arr[1])).value
elif len(arr) > 1:
    self.lanes = ctypes.c_uint16(int(arr[1])).value
else:
    # No "x<N>" suffix: the type has exactly one lane.
    self.lanes = 1
bits = 32

if head.startswith("int"):
Expand Down Expand Up @@ -188,8 +191,11 @@ def __repr__(self):

type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits)
if self.lanes != 1:
lanes_as_int = ctypes.c_int16(self.lanes).value
if lanes_as_int > 1:
x += "x%d" % self.lanes
elif lanes_as_int < 0:
x += "x%dxvscale" % -lanes_as_int
return x

def __eq__(self, other):
Expand Down
Loading