Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,26 @@ class DataType {
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
*/
DataType(int code, int bits, int lanes) {
DataType(int code, int bits, int lanes, bool is_scalable = false) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
is_scalable_ = is_scalable;
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
}
// DataType(int code, int bits) {
// data_.code = static_cast<uint8_t>(code);
// data_.bits = static_cast<uint8_t>(bits);
// is_scalable_ = true;
// std::cout<<bits<<std::endl;
// data_.lanes = uint16_t(128) / static_cast<uint16_t>(8); // minimal lanes
//
//// if (code == kBFloat) {
//// ICHECK_EQ(bits, 16);
//// }
// }
Comment on lines +81 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
/*! \return number of bits in the data. */
Expand Down Expand Up @@ -107,6 +119,13 @@ class DataType {
bool is_vector_bool() const { return is_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
bool is_scalable() const { return is_scalable_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doxygen comments.


/*!
 * \brief Create a scalable variant of this type.
 * \return A copy of this type marked scalable, with lanes set to the
 *         minimum lane count for a 128-bit vector register.
 * \note TODO(review): 128 is hard-coded as the minimum vector width; SVE-style
 *       targets may want this configurable between 128 and 2048 bits.
 */
DataType with_scalable_lanes() const {
  // Guard BEFORE dividing: a 0-bit type (e.g. void) would divide by zero.
  ICHECK_NE(bits(), 0) << "Cannot derive scalable lanes for a 0-bit type";
  int min_num_lanes = 128 / bits();
  return DataType(data_.code, data_.bits, min_num_lanes, true);
}
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
Expand All @@ -131,7 +150,7 @@ class DataType {
*/
bool operator==(const DataType& other) const {
  // Include is_scalable_ so a scalable vector type never compares equal to a
  // fixed-length vector type with the same code/bits/lanes (mirrors the
  // is_scalable handling in RampNode::SEqualReduce).
  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
         data_.lanes == other.data_.lanes && is_scalable_ == other.is_scalable_;
}
/*!
* \brief NotEqual comparator.
Expand All @@ -151,21 +170,27 @@ class DataType {
* \param lanes The number of lanes.
* \return The constructed data type.
*/
static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
static DataType Int(int bits, int lanes = 1, bool is_scalable = false) {
  // is_scalable tags the result as a scalable (vector-length agnostic) vector;
  // the flag is stored by the DataType constructor alongside code/bits/lanes.
  return DataType(kDLInt, bits, lanes, is_scalable);
}
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
  // is_scalable tags the result as a scalable (vector-length agnostic) vector;
  // the flag is stored by the DataType constructor alongside code/bits/lanes.
  return DataType(kDLUInt, bits, lanes, is_scalable);
}
/*!
* \brief Construct an float type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
static DataType Float(int bits, int lanes = 1, bool is_scalable = false) {
  // is_scalable tags the result as a scalable (vector-length agnostic) vector;
  // the flag is stored by the DataType constructor alongside code/bits/lanes.
  return DataType(kDLFloat, bits, lanes, is_scalable);
}
/*!
* \brief Construct an bfloat type.
* \param bits The number of bits in the type.
Expand All @@ -178,7 +203,9 @@ class DataType {
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
static DataType Bool(int lanes = 1, bool is_scalable = false) {
  // Bool is represented as a 1-bit unsigned integer; is_scalable is forwarded
  // to UInt and tags the result as a scalable (vector-length agnostic) vector.
  return DataType::UInt(1, lanes, is_scalable);
}
/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
Expand All @@ -204,6 +231,7 @@ class DataType {
}

private:
bool is_scalable_{false};
DLDataType data_;
};

Expand Down Expand Up @@ -285,6 +313,8 @@ inline DLDataType String2DLDataType(std::string s);
*/
inline std::string DLDataType2String(DLDataType t);

inline std::string VLADataType2String(DLDataType t);

// implementation details
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
switch (static_cast<int>(type_code)) {
Expand Down Expand Up @@ -336,6 +366,17 @@ inline std::string DLDataType2String(DLDataType t) {
return os.str();
}

/*!
 * \brief Convert a DataType to its string form with a scalable ("xVL") suffix.
 * \param t The data type.
 * \return The corresponding string, or "" for a 0-bit type.
 */
inline std::string VLADataType2String(DataType t) {
  if (t.bits() == 0) return "";
  std::ostringstream os;
  // Print the fixed-lane spelling first, then append the
  // vector-length-agnostic marker.
  os << t.operator DLDataType();
  os << "xVL";
  return os.str();
}

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle void type
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ class Stage : public ObjectRef {
* \return reference to self.
*/
TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)

/*!
 * \brief Vectorize iteration with a scalable vector length (VL).
 * \param var The axis to be vectorized.
 * \return reference to self.
 */
TVM_DLL Stage& vectorize_scalable(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis marks beginning of tensorization.
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ class RampNode : public PrimExprNode {
PrimExpr stride;
/*! \brief Total number of lanes. */
int lanes;
bool is_scalable;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand All @@ -779,14 +780,15 @@ class RampNode : public PrimExprNode {

bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
  // Structural equality must also distinguish scalable from fixed-length ramps.
  return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
         equal(lanes, other->lanes) && equal(is_scalable, other->is_scalable);
}

// Fold all semantically relevant fields (including is_scalable) into the
// structural hash, mirroring the fields compared in SEqualReduce.
void SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(dtype);
  hash_reduce(base);
  hash_reduce(stride);
  hash_reduce(lanes);
  hash_reduce(is_scalable);
}

static constexpr const char* _type_key = "tir.Ramp";
Expand All @@ -800,6 +802,7 @@ class RampNode : public PrimExprNode {
class Ramp : public PrimExpr {
public:
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, bool is_scalable, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
};

Expand Down Expand Up @@ -839,6 +842,7 @@ class BroadcastNode : public PrimExprNode {
class Broadcast : public PrimExpr {
public:
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
TVM_DLL Broadcast(PrimExpr value, int lanes, bool is_scalable, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
};

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,8 @@ inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value, span);
} else {
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), t.is_scalable(),
span);
}
}

Expand Down
11 changes: 9 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,9 @@ enum class ForKind : int {
* the loop is simply removed and the loop variable is
* mapped to the corresponding context thread.
*/
kThreadBinding = 4
kThreadBinding = 4,
/*! \brief Loop is vectorized but the vector length (VL) is unknown. */
kVectorizedScalable = 5
};

/*!
Expand Down Expand Up @@ -822,6 +824,8 @@ class ForNode : public StmtNode {
* and can be ignored in most passes.
*/
Map<String, ObjectRef> annotations;
/*! \brief Whether the loop is vector-length agnostic (scalable vectorized). */
bool is_vla;
/*! \brief Presumably the loop/vector stride — TODO(review) confirm semantics
 *  and document; defaults to 1 in the For constructor below. */
int stride;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
Expand Down Expand Up @@ -862,7 +866,8 @@ class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span(), bool is_vla = false, int stride = 1);

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
Expand Down Expand Up @@ -1369,6 +1374,8 @@ inline const char* ForKind2String(ForKind t) {
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kVectorizedScalable:
return "vectorized_scalable";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ TVM_DLL Pass LoopPartition();
*/
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);

/*!
 * \brief Lower scalable (vector-length agnostic) vectorization loops,
 *        i.e. loops marked ForKind::kVectorizedScalable.
 *
 * \param enable_vectorize Whether vectorization is enabled.
 *
 * \return The pass.
 */
TVM_DLL Pass VectorizeLoopScalable(bool enable_vectorize = true);

/*!
* \brief Inject virtual thread loops.
*
Expand Down
8 changes: 7 additions & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,11 @@ enum IterVarType : int {
/*!
* \brief Marks boundary of tensorization intrinsic.
*/
kTensorized = 8
kTensorized = 8,
/*!
* \brief The loop is vectorized with a scalable vector length
*/
kVectorizedScalable = 9
};

/*!
Expand Down Expand Up @@ -324,6 +328,8 @@ inline const char* IterVarType2String(IterVarType t) {
return "Parallelized";
case kTensorized:
return "Tensorized";
case kVectorizedScalable:
return "VectorizedScalable";
}
return "Unknown";
}
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def apply(
elif ann == "vec":
if vec_size and axis_lens[i] not in vec_size:
cfg.raise_error("Wrong size of lanes in vectorization")
sch[op].vectorize(axes[i])
sch[op].vectorize_scalable(axes[i])
elif ann == "blockIdx.x":
sch[op].bind(axes[i], thread_axis("blockIdx.x"))
elif ann == "blockIdx.y":
Expand Down
1 change: 1 addition & 0 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"unroll": ForKind.UNROLLED,
"parallel": ForKind.PARALLEL,
"vectorize": ForKind.VECTORIZED,
"vectorize_scalable": ForKind.VECTORIZED_SCALABLE,
"const_range": (ForKind.UNROLLED,),
}

Expand Down
12 changes: 11 additions & 1 deletion python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,20 @@ def vectorize(self, var):
Parameters
----------
var : IterVar
The iteration to be vectorize
The iteration to be vectorized
"""
_ffi_api.StageVectorize(self, var)

def vectorize_scalable(self, var):
    """Vectorize the iteration with a scalable vector length (VL).

    Unlike `vectorize`, the vector length is not fixed at compile time.

    Parameters
    ----------
    var : IterVar
        The iteration to be vectorized
    """
    _ffi_api.StageVectorizeScalable(self, var)

def tensorize(self, var, tensor_intrin):
"""Tensorize the computation enclosed by var with tensor_intrin

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def _exit_cb():
kind_id = _stmt.ForKind.VECTORIZED
elif kind == "unroll":
kind_id = _stmt.ForKind.UNROLLED
elif kind == "vectorize_scalable":
kind_id = _stmt.ForKind.VECTORIZED_SCALABLE
else:
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class ForKind(IntEnum):
VECTORIZED = 2
UNROLLED = 3
THREAD_BINDING = 4
VECTORIZED_SCALABLE = 5


@tvm._ffi.register_object("tir.For")
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,23 @@ def VectorizeLoop(enable_vectorize: bool = True):
return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore


def VectorizeLoopScalable(enable_vectorize: bool = True):
    """Lower scalable (vector-length agnostic) vectorization loops.

    Parameters
    ----------
    enable_vectorize : bool
        Whether vectorization is enabled.
        Will lower to scalar loop when it is turned off.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.VectorizeLoopScalable(enable_vectorize)  # type: ignore


def InjectVirtualThread():
"""Inject virtual thread loops.

Expand Down
3 changes: 3 additions & 0 deletions src/autotvm/feature_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
case ForKind::kVectorized:
ann = kVectorized;
break;
case ForKind::kVectorizedScalable:
ann = kVectorizedScalable;
break;
case ForKind::kSerial:
ann = kSerial;
break;
Expand Down
1 change: 1 addition & 0 deletions src/autotvm/feature_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ enum AnnotationType {
kThreadZ,
kUnrolled,
kVectorized,
kVectorizedScalable,
kParallel,
kSerial,
kVirtualThread,
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::LoopPartition());
}

pass_list.push_back(tir::transform::VectorizeLoopScalable(!disable_vectorize));
pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
Expand Down
Loading