diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 01baaa8d13a2..7d1f315b3cb3 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -210,6 +210,73 @@ class PVar : public Pattern> { mutable bool filled_{false}; }; +/*! + * \brief Wrapper for pattern variable container with extra match logic. + * + * \tparam Derived the type of derived class. + * \tparam T the type of the hole. + */ +template +class PVarWithCheck : public arith::Pattern> { + public: + // Store by reference in the expression. + using Nested = const PVarWithCheck&; + + void InitMatch_() const { pvar_.InitMatch_(); } + + bool Match_(const T& value) const { + if (!static_cast(this)->Match_(value)) return false; + return pvar_.Match_(value); + } + + template ::value>::type> + bool Match_(const NodeRefType& value) const { + if (const auto* ptr = value.template as()) { + return Match_(GetRef(ptr)); + } else { + return false; + } + } + + T Eval() const { return pvar_.Eval(); } + + protected: + arith::PVar pvar_; +}; + +/*! + * \brief Pattern variable container with expr type check. + * + * \tparam T the type of the hole. + * \tparam DType the Pattern type of dtype. + */ +template ::value>> +class PVarWithDataType : public PVarWithCheck, T> { + public: + explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} + + bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } + + protected: + typename DType::Nested dtype_; +}; + +/*! + * \brief Pattern variable container for data type with lanes. + */ +class PVecDataType : public PVarWithCheck { + public: + /*! \brief construct vector dtype placeholder with element type check */ + explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} + + bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } + + protected: + DataType elem_dtype_; +}; + /*! * \brief Constant Pattern variable container. * @@ -467,7 +534,7 @@ class PCastExpr : public Pattern> { /*! * \brief Construct a cast pattern. * - * \param dtype The target data type, can be PVar or PConst. + * \param dtype The target data type, can be PVar or PConst. * \param value The input type. * * \return The result pattern. diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 4194c760628a..2e386c48b75c 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -138,3 +138,25 @@ TEST(Pattern, IntImm) { // cannot match tx + 1 to v ICHECK(!(v * c).Match((tx + 1) * 3)); } + +TEST(Pattern, MatchWithType) { + using namespace tvm; + // match expr with specified dtype + arith::PVarWithDataType> pat(DataType::Float(32)); + tir::Var x("x", DataType::Float(32)); + tir::Var y("y", DataType::Float(32)); + tir::Var x_int("x", DataType::Int(32)); + tir::Var y_int("y", DataType::Int(32)); + ICHECK(pat.Match(x + y * 2.0f)); + ICHECK(!pat.Match(x_int + y_int * 2)); + + // match vectorized expr with specified element dtype + arith::PVecDataType vec_ty(DataType::Float(32)); + arith::PVarWithDataType vpat(vec_ty); + tir::Var vx = tir::Var("x", DataType::Float(32, 8)); + tir::Var vy("y", DataType::Float(32, 8)); + tir::Var vx_int("x", DataType::Int(32, 8)); + tir::Var vy_int("y", DataType::Int(32, 8)); + ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8))); + ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8))); +}