diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 89b57be80ab1e..5245ce0e46646 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -16,7 +16,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding); @@ -43,7 +42,6 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index edff03fbaa900..fb6eadcdae52c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1089,39 +1089,6 @@ The bounding box coordinates corresponding to the selected indices can then be o updateOutputShape(ctx, 0, input_shape); }); - ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer) - .SetDomain(kMSDomain) - .SinceVersion(1) - .Input(0, "X", "Strings to normalize", "T") - .Output(0, "Y", "Normalized strings", "T") - .TypeConstraint( - "T", - {"tensor(string)"}, - "Input/Output is a string tensor") - .Attr( - "casechangeaction", - "string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"", - AttributeProto::STRING) - .Attr( - "is_case_sensitive", - "Boolean. Whether the identification of stop words in X is case-sensitive.", - AttributeProto::INT) - .Attr( - "stopwords", - "List of stop words", - AttributeProto::STRINGS, - OPTIONAL) - .Attr( - "locale", - "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", - AttributeProto::STRING, - OPTIONAL) - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); - }) - .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); - ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3727852897ceb..9d4ddb147a546 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -269,6 +269,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where); +// Opset 10 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer); + void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); @@ -529,6 +532,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + + // Opset 10 + kernel_registry.Register(BuildKernelCreateInfo()); } // Forward declarations of ml op kernels diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/core/providers/cpu/nn/string_normalizer.cc similarity index 91% rename from onnxruntime/contrib_ops/cpu/string_normalizer.cc rename to onnxruntime/core/providers/cpu/nn/string_normalizer.cc index a20eaade10d3b..7ce49bef4a427 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/core/providers/cpu/nn/string_normalizer.cc @@ -16,15 +16,13 @@ #include namespace onnxruntime { -namespace contrib { -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( +ONNX_CPU_OPERATOR_KERNEL( StringNormalizer, - 1, - string, + 10, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - contrib::StringNormalizer); + StringNormalizer); namespace string_normalizer { const std::string conv_error("Conversion Error"); @@ -157,29 +155,29 @@ using namespace string_normalizer; StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), is_case_sensitive_(true), - casechangeaction_(NONE), + case_change_action_(NONE), compare_caseaction_(NONE) { int64_t iscasesensitive = 0; Status status = info.GetAttr("is_case_sensitive", &iscasesensitive); ORT_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set"); is_case_sensitive_ = iscasesensitive != 0; - std::string casechangeaction; - status = info.GetAttr("casechangeaction", &casechangeaction); - ORT_ENFORCE(status.IsOK(), "attribute caseaction is not set"); - if (casechangeaction == "LOWER") { - casechangeaction_ = LOWER; - } else if (casechangeaction == "UPPER") { - casechangeaction_ = UPPER; - } else if (casechangeaction == "NONE") { - casechangeaction_ = NONE; + std::string case_change_action; + status = info.GetAttr("case_change_action", &case_change_action); + ORT_ENFORCE(status.IsOK(), "attribute case_change_action is not set"); + if (case_change_action == "LOWER") { + case_change_action_ = LOWER; + } else if (case_change_action == "UPPER") { + case_change_action_ = UPPER; + } else if (case_change_action == "NONE") { + case_change_action_ = NONE; } else { - ORT_ENFORCE(false, "attribute casechangeaction has invalid value"); + ORT_ENFORCE(false, "attribute case_change_action has invalid value"); } if (!is_case_sensitive_) { // Convert stop words to a case which can help us preserve the case of filtered strings - compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER; + compare_caseaction_ = (case_change_action_ == UPPER) ? UPPER : LOWER; } locale_name_ = info.GetAttrOrDefault("locale", default_locale); @@ -248,10 +246,10 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { ++first; } status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter, - N, filtered_strings.size(), casechangeaction_); + N, filtered_strings.size(), case_change_action_); } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, case_change_action_); } } else { if (!wstopwords_.empty()) { @@ -273,7 +271,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } locale.ChangeCase(compare_caseaction_, wstr); if (0 == wstopwords_.count(wstr)) { - if (casechangeaction_ == NONE) { + if (case_change_action_ == NONE) { filtered_orignal_strings.push_back(std::cref(s)); } else { filtered_cased_strings.push_back(converter.to_bytes(wstr)); @@ -281,7 +279,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } ++first; } - if (casechangeaction_ == NONE) { + if (case_change_action_ == NONE) { status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale, converter, N, filtered_orignal_strings.size(), NONE); } else { @@ -290,10 +288,9 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, case_change_action_); } } return status; } -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/core/providers/cpu/nn/string_normalizer.h similarity index 91% rename from onnxruntime/contrib_ops/cpu/string_normalizer.h rename to onnxruntime/core/providers/cpu/nn/string_normalizer.h index 8bc865400f6d4..1f15926060f86 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.h +++ b/onnxruntime/core/providers/cpu/nn/string_normalizer.h @@ -10,7 +10,6 @@ #include namespace onnxruntime { -namespace contrib { class StringNormalizer : public OpKernel { public: @@ -27,7 +26,7 @@ class StringNormalizer : public OpKernel { private: bool is_case_sensitive_; - CaseAction casechangeaction_; + CaseAction case_change_action_; CaseAction compare_caseaction_; // used for case-insensitive compare std::string locale_name_; // Either if these are populated but not both @@ -35,5 +34,4 @@ class StringNormalizer : public OpKernel { std::unordered_set wstopwords_; }; -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 18585041d02ac..b6fbefbddfea6 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -301,18 +301,6 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) { {"atanh_example", "opset 9 not supported yet"}, {"scan_sum", "opset 9 not supported yet"}, {"shrink", "opset 9 not supported yet"}, - {"strnormalizer_export_monday_casesensintive_lower", "opset 10 not supported yet"}, - {"strnormalizer_export_monday_casesensintive_nochangecase", "opset 10 not supported yet"}, - {"strnormalizer_export_monday_casesensintive_upper", "opset 10 not supported yet"}, - {"strnormalizer_export_monday_empty_output", "opset 10 not supported yet"}, - {"strnormalizer_export_monday_insensintive_upper_twodim", "opset 10 not supported yet"}, - {"strnormalizer_nostopwords_nochangecase", "opset 10 not supported yet"}, - {"strnorm_model_monday_casesensintive_lower", "opset 10 not supported yet"}, - {"strnorm_model_monday_casesensintive_nochangecase", "opset 10 not supported yet"}, - {"strnorm_model_monday_casesensintive_upper", "opset 10 not supported yet"}, - {"strnorm_model_monday_empty_output", "opset 10 not supported yet"}, - {"strnorm_model_monday_insensintive_upper_twodim", "opset 10 not supported yet"}, - {"strnorm_model_nostopwords_nochangecase", "opset 10 not supported yet"}, {"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"}, {"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"}, {"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"}, diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc similarity index 88% rename from onnxruntime/test/contrib_ops/string_normalizer_test.cc rename to onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc index e3374f60c0195..3da6533394bfa 100644 --- a/onnxruntime/test/contrib_ops/string_normalizer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc @@ -9,8 +9,8 @@ namespace onnxruntime { namespace test { namespace str_normalizer_test { -constexpr const char* domain = onnxruntime::kMSDomain; -const int opset_ver = 1; +constexpr const char* domain = kOnnxDomain; +const int opset_ver = 10; #ifdef _MSC_VER const std::string test_locale("en-US"); @@ -18,12 +18,14 @@ const std::string test_locale("en-US"); const std::string test_locale("en_US.UTF-8"); #endif -void InitTestAttr(OpTester& test, const std::string& casechangeaction, - bool iscasesensitive, +void InitTestAttr(OpTester& test, const std::string& case_change_action, + bool is_case_sensitive, const std::vector& stopwords, const std::string& locale) { - test.AddAttribute("casechangeaction", casechangeaction); - test.AddAttribute("is_case_sensitive", int64_t{iscasesensitive}); + if (!case_change_action.empty()) { + test.AddAttribute("case_change_action", case_change_action); + } + test.AddAttribute("is_case_sensitive", int64_t{is_case_sensitive}); if (!stopwords.empty()) { test.AddAttribute("stopwords", stopwords); } @@ -36,27 +38,12 @@ void InitTestAttr(OpTester& test, const std::string& casechangeaction, using namespace str_normalizer_test; TEST(ContribOpTest, StringNormalizerTest) { - // Test wrong 2 dimensions - // - casesensitive approach - // - no stopwords. - // - No change case action - { - OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "NONE", true, {}, test_locale); - std::vector dims{2, 2}; - std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; - test.AddInput("T", dims, input); - std::vector output(input); // do the same for now - test.AddOutput("Y", dims, output); - - test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed"); - } // - casesensitive approach // - no stopwords. - // - No change case action + // - No change case action, expecting default to take over { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "NONE", true, {}, test_locale); + InitTestAttr(test, "", true, {}, test_locale); std::vector dims{4}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 2f2f16afc9902..588f717214e20 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -6,29 +6,39 @@ import unittest import onnx.backend.test +import numpy as np import onnxruntime.backend as c2 pytest_plugins = 'onnx.backend.test.report', -class OrtBackendTest(onnx.backend.test.BackendTest): +class OnnxruntimeBackendTest(onnx.backend.test.BackendTest): - def __init__(self, backend, parent_module=None): - super(OrtBackendTest, self).__init__(backend, parent_module) + def __init__(self, backend, parent_module=None): + onnx.backend.test.BackendTest.__init__(self, backend, parent_module) - @classmethod - def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): - # type: (Sequence[Any], Sequence[Any], float, float) -> None + @classmethod + def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): + np.testing.assert_equal(len(ref_outputs), len(outputs)) + for i in range(len(outputs)): + np.testing.assert_equal(ref_outputs[i].dtype, outputs[i].dtype) + if ref_outputs[i].dtype == np.object: + np.testing.assert_array_equal(ref_outputs[i], outputs[i]) + else: + np.testing.assert_allclose( + ref_outputs[i], + outputs[i], + rtol=rtol, + atol=atol) - # override the rtol and atol values to match onnx_test_runner tolerances - super(OrtBackendTest, cls).assert_similar_outputs(ref_outputs, outputs, 1e-3, 1e-5) -backend_test = OrtBackendTest(c2, __name__) +backend_test = OnnxruntimeBackendTest(c2, __name__) # Type not supported backend_test.exclude(r'(FLOAT16)') backend_test.exclude(r'(' '^test_cast_DOUBLE_to_FLOAT_cpu.*' +'|^test_gru_seq_length_cpu.*' '|^test_cast_FLOAT_to_DOUBLE_cpu.*' '|^test_cast_FLOAT_to_STRING_cpu.*' '|^test_cast_STRING_to_FLOAT_cpu.*' @@ -58,8 +68,6 @@ def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): '|^test_PReLU_3d_cpu.*' '|^test_PReLU_3d_multiparam_cpu.*' '|^test_PoissonNLLLLoss_no_reduce_cpu.*' -'|^test_strnormalizer_*.*' -'|^test_strnorm_*.*' '|^test_Softsign_cpu.*' '|^test_operator_add_broadcast_cpu.*' '|^test_operator_add_size1_broadcast_cpu.*'