diff --git a/src/relay/backend/name_transforms.cc b/src/relay/backend/name_transforms.cc index 4f364b811bcc..a527d38fb84e 100644 --- a/src/relay/backend/name_transforms.cc +++ b/src/relay/backend/name_transforms.cc @@ -29,27 +29,34 @@ namespace tvm { namespace relay { namespace backend { -std::string ToCFunctionStyle(const std::string& original_name) { - ICHECK(!original_name.empty()) << "Function name is empty"; - ICHECK_EQ(original_name.find("TVM"), 0) << "Function not TVM prefixed"; - - int tvm_prefix_length = 3; - std::string function_name("TVM"); +std::string ToCamel(const std::string& original_name) { + std::string camel_name; + camel_name.reserve(original_name.size()); bool new_block = true; - for (const char& symbol : original_name.substr(tvm_prefix_length)) { + for (const char& symbol : original_name) { if (std::isalpha(symbol)) { if (new_block) { - function_name.push_back(std::toupper(symbol)); + camel_name.push_back(std::toupper(symbol)); new_block = false; } else { - function_name.push_back(std::tolower(symbol)); + camel_name.push_back(std::tolower(symbol)); } } else if (symbol == '_') { new_block = true; } } - return function_name; + return camel_name; +} + +std::string ToCFunctionStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Function name is empty"; + ICHECK_EQ(original_name.find("TVM"), 0) << "Function not TVM prefixed"; + + int tvm_prefix_length = 3; + std::string function_prefix("TVM"); + + return function_prefix + ToCamel(original_name.substr(tvm_prefix_length)); } std::string ToCVariableStyle(const std::string& original_name) { @@ -71,6 +78,30 @@ std::string ToCConstantStyle(const std::string& original_name) { return constant_name; } +std::string ToRustStructStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Struct name is empty"; + return ToCamel(original_name); +} + +std::string ToRustMacroStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Macro name is empty"; + + std::string macro_name; + macro_name.resize(original_name.size()); + + std::transform(original_name.begin(), original_name.end(), macro_name.begin(), ::tolower); + return macro_name; +} + +std::string ToRustConstantStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Constant name is empty"; + std::string constant_name; + constant_name.resize(original_name.size()); + + std::transform(original_name.begin(), original_name.end(), constant_name.begin(), ::toupper); + return constant_name; +} + std::string CombineNames(const Array& names) { std::stringstream combine_stream; ICHECK(!names.empty()) << "Name segments empty"; diff --git a/src/relay/backend/name_transforms.h b/src/relay/backend/name_transforms.h index f59280af2222..fab518debc63 100644 --- a/src/relay/backend/name_transforms.h +++ b/src/relay/backend/name_transforms.h @@ -79,6 +79,30 @@ std::string ToCVariableStyle(const std::string& original_name); */ std::string ToCConstantStyle(const std::string& original_name); +/*! + * \brief Transform a name to the Rust struct style assuming it is + * appropriately constructed using the combining functions + * \param name Original name + * \return Transformed function in the Rust struct style + */ +std::string ToRustStructStyle(const std::string& original_name); + +/*! + * \brief Transform a name to the Rust macro style assuming it is + * appropriately constructed using the combining functions + * \param name Original name + * \return Transformed function in the Rust macro style + */ +std::string ToRustMacroStyle(const std::string& original_name); + +/*! + * \brief Transform a name to the Rust constant style assuming it is + * appropriately constructed using the combining functions + * \param name Original name + * \return Transformed function in the Rust constant style + */ +std::string ToRustConstantStyle(const std::string& original_name); + /*! * \brief Combine names together for use as a generated name * \param names Vector of strings to combine diff --git a/tests/cpp/name_transforms_test.cc b/tests/cpp/name_transforms_test.cc index 12a2ce1d0761..7e3cfe1d779c 100644 --- a/tests/cpp/name_transforms_test.cc +++ b/tests/cpp/name_transforms_test.cc @@ -23,15 +23,20 @@ #include #include -using namespace tvm::relay::backend; +namespace tvm { +namespace relay { +namespace backend { + using namespace tvm::runtime; +std::string ToCamel(const std::string& original_name); + TEST(NameTransforms, ToCFunctionStyle) { ASSERT_EQ(ToCFunctionStyle("TVM_Woof"), "TVMWoof"); ASSERT_EQ(ToCFunctionStyle("TVM_woof"), "TVMWoof"); ASSERT_EQ(ToCFunctionStyle("TVM_woof_woof"), "TVMWoofWoof"); ASSERT_EQ(ToCFunctionStyle("TVMGen_woof_woof"), "TVMGenWoofWoof"); - EXPECT_THROW(ToCVariableStyle("Cake_Bakery"), InternalError); // Incorrect prefix + EXPECT_THROW(ToCFunctionStyle("Cake_Bakery"), InternalError); // Incorrect prefix EXPECT_THROW(ToCFunctionStyle(""), InternalError); } @@ -51,6 +56,27 @@ TEST(NameTransforms, ToCConstantStyle) { EXPECT_THROW(ToCConstantStyle(""), InternalError); } +TEST(NameTransforms, ToRustStructStyle) { + ASSERT_EQ(ToRustStructStyle("Woof"), "Woof"); + ASSERT_EQ(ToRustStructStyle("woof"), "Woof"); + ASSERT_EQ(ToRustStructStyle("woof_woof"), "WoofWoof"); + EXPECT_THROW(ToRustStructStyle(""), InternalError); +} + +TEST(NameTransforms, ToRustMacroStyle) { + ASSERT_EQ(ToRustMacroStyle("Woof"), "woof"); + ASSERT_EQ(ToRustMacroStyle("woof"), "woof"); + ASSERT_EQ(ToRustMacroStyle("woof_Woof"), "woof_woof"); + EXPECT_THROW(ToRustMacroStyle(""), InternalError); +} + +TEST(NameTransforms, ToRustConstantStyle) { + ASSERT_EQ(ToRustConstantStyle("Woof"), "WOOF"); + ASSERT_EQ(ToRustConstantStyle("woof"), "WOOF"); + ASSERT_EQ(ToRustConstantStyle("woof_Woof"), "WOOF_WOOF"); + EXPECT_THROW(ToRustConstantStyle(""), InternalError); +} + TEST(NameTransforms, PrefixName) { ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof"); ASSERT_EQ(PrefixName({"woof"}), "TVM_woof"); @@ -94,3 +120,23 @@ TEST(NameTransforms, CombinedLogic) { ASSERT_EQ(ToCVariableStyle(PrefixName({"Device", "target", "t"})), "tvm_device_target_t"); ASSERT_EQ(ToCVariableStyle(PrefixGeneratedName({"model", "Devices"})), "tvmgen_model_devices"); } + +TEST(NameTransforms, Internal_ToCamel) { + ASSERT_EQ(ToCamel("Woof"), "Woof"); + ASSERT_EQ(ToCamel("woof"), "Woof"); + ASSERT_EQ(ToCamel("woof_woof"), "WoofWoof"); +} + +TEST(NameTransforms, Internal_ToCamel_Allocation) { + std::string woof = "Woof_woof_woof_woof"; + std::string camel = ToCamel(woof); + std::string check; + check.reserve(woof.size()); + + // Check that the pre-allocation happens + ASSERT_EQ(camel.capacity(), check.capacity()); +} + +} // namespace backend +} // namespace relay +} // namespace tvm