From 39b302eb73b9e55c672d2db7056eb2dad2ad71ed Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 29 Nov 2018 18:17:19 -0800 Subject: [PATCH 01/13] Imlpement StringNormalizer TODO: Fix empty output case. --- onnxruntime/contrib_ops/contrib_ops.cc | 36 ++++ .../contrib_ops/cpu/string_normalizer.cc | 203 ++++++++++++++++++ .../contrib_ops/cpu/string_normalizer.h | 35 +++ .../contrib_ops/string_normalizer_test.cc | 143 ++++++++++++ 4 files changed, 417 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/string_normalizer.cc create mode 100644 onnxruntime/contrib_ops/cpu/string_normalizer.h create mode 100644 onnxruntime/test/contrib_ops/string_normalizer_test.cc diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index 222b183dc6392..8f91b09a923d9 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -452,6 +452,40 @@ The bounding box coordinates corresponding to the selected indices can then be o ->set_dim_value(1); } }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "X", "Strings to normalize", "T") + .Output(0, "Y", "Normalized strings", "T") + .TypeConstraint( + "T", + {"tensor(string)"}, + "Input/Output is a string tensor") + .Attr( + "casechangeaction", + "string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"", + AttributeProto::STRING) + .Attr( + "iscasesensitive", + "Boolean. Whether the identification of stop words in X is case-sensitive.", + AttributeProto::INT) + .Attr( + "stopwords", + "List of stop words", + AttributeProto::STRINGS, + OPTIONAL) + .Attr( + "locale", + "Platform dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", + AttributeProto::STRING, + "en_US") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); + }) + .SetDoc(R"DOC([optional] Step1: Remove elements in X if they matches any of stop words so that output tensor may not contain any stop word. This operator only accepts [C]- and [1, C]-tensor. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C]. +[optional] Step2: Lower all characters (if action is LOWER) in X or capitalize them (when action is UPPER))DOC"); } class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); @@ -461,6 +495,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); void RegisterContribKernels(std::function fn) { @@ -474,6 +509,7 @@ void RegisterContribKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc new file mode 100644 index 0000000000000..a2c1b77d3ad22 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "string_normalizer.h" +#include "onnx/defs/schema.h" +#include "core/common/common.h" +#include "core/framework/tensor.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + StringNormalizer, + 1, + string, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + contrib::StringNormalizer); + +namespace string_normalizer { +const std::string conv_error("Conversion Error"); +const std::wstring wconv_error(L"Conversion Error"); +// performs tolower/toupper in-place +void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction, + std::wstring& wstr) { + if (caseaction == StringNormalizer::LOWER) { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [&loc](wchar_t ch) { return std::tolower(ch, loc); }); + } else { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [&loc](wchar_t ch) { return std::toupper(ch, loc); }); + } +} + +template +Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, + const std::locale& loc, + std::wstring_convert>& converter, + size_t N, size_t C, + StringNormalizer::CaseAction caseaction) { + std::vector output_dims; + if (N == 1) { + output_dims.push_back(1); + } + + // Empty output case + if (C == 0) { + output_dims.push_back(1); + TensorShape output_shape(output_dims); + ctx->Output(0, output_shape); + return Status::OK(); + } + + output_dims.push_back(C); + + TensorShape output_shape(output_dims); + auto output_tensor = ctx->Output(0, output_shape); + auto const output_data = output_tensor->template MutableData(); + + size_t output_idx = 0; + while (first != end) { + const std::string& s = *first; + if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) { + std::wstring wstr = converter.from_bytes(s); + if (wstr == wconv_error) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input contains invalid utf8 chars at: " + s); + } + // In place transform + ChangeCase(loc, caseaction, wstr); + new (output_data + output_idx) std::string(converter.to_bytes(wstr)); + } else { + // Simple copy + new (output_data + output_idx) std::string(s); + } + ++output_idx; + ++first; + } + return Status::OK(); +} +} // namespace string_normalizer + +StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { + std::string casechangeaction; + auto status = info.GetAttr("casechangeaction", &casechangeaction); + ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute caseaction is not set"); + if (casechangeaction == "LOWER") { + casechangeaction_ = LOWER; + } else if (casechangeaction == "UPPER") { + casechangeaction_ = UPPER; + } else if (casechangeaction == "NONE") { + casechangeaction_ = NONE; + } else { + ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value"); + } + int64_t iscasesensitive = 0; + status = info.GetAttr("iscasesensitive", &iscasesensitive); + ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute iscasesensitive is not set"); + iscasesensitive_ = iscasesensitive != 0; + + info.GetAttrs("stopwords", stopwords_); + // Default is specified in the schema + status = info.GetAttr("locale", &locale_); + ONNXRUNTIME_ENFORCE(status.IsOK(), "Failed to get locale"); +} + +Status StringNormalizer::Compute(OpKernelContext* ctx) const { + using namespace string_normalizer; + + auto X = ctx->Input(0); + auto& input_dims = X->Shape().GetDims(); + + size_t N = 0; + size_t C = 0; + if (input_dims.size() == 1) { + if (input_dims[0] < 1) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Single dimension value must be greater than 0"); + } + C = input_dims[0]; + } else if (input_dims.size() == 2) { + if (input_dims[0] != 1 || input_dims[1] < 1) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + N = 1; + C = input_dims[1]; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + + Status status; + std::locale loc(locale_); + std::wstring_convert> converter(conv_error, wconv_error); + auto const input_data = X->template Data(); + + using StrRef = std::reference_wrapper; + if (iscasesensitive_) { + if (!stopwords_.empty()) { + // Create a filter and create filtered output + std::unordered_set, + std::equal_to> + swords; + std::transform(stopwords_.begin(), stopwords_.end(), std::inserter(swords, swords.end()), + [](const std::string& s) { return std::cref(s); }); + + std::vector filtered_strings; + filtered_strings.reserve(C); + for (size_t input_idx = 0; input_idx < C; ++input_idx) { + const std::string& s = *(input_data + input_idx); + if (0 == swords.count(s)) { + filtered_strings.push_back(std::cref(s)); + } + } + status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, loc, converter, + N, filtered_strings.size(), casechangeaction_); + } else { + // Nothing to filter. Copy input to output and change case if needed + status = CopyCaseAction(input_data, input_data + C, ctx, loc, converter, N, C, casechangeaction_); + } + } else { + if (!stopwords_.empty()) { + // Perform case-insensitive comparison. Convert to lowercase for NONE, LOWER and UPPER otherwise. + const CaseAction ca = (casechangeaction_ == UPPER) ? UPPER : LOWER; + std::unordered_set swords; + std::transform(stopwords_.begin(), stopwords_.end(), std::inserter(swords, swords.end()), + [&loc, &converter, ca](const std::string& s) { + std::wstring wstr = converter.from_bytes(s); + ChangeCase(loc, ca, wstr); + return wstr; + }); + + // Filter input. We choose to undergo conversion twice (if needed) + // as oppose to preserve lower/uppercased strings to favor lower memory + // consumption. + std::vector filtered_strings; + filtered_strings.reserve(C); + for (size_t input_idx = 0; input_idx < C; ++input_idx) { + const std::string& s = *(input_data + input_idx); + std::wstring wstr = converter.from_bytes(s); + ChangeCase(loc, ca, wstr); + if (0 == swords.count(wstr)) { + filtered_strings.push_back(std::cref(s)); + } + } + status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, loc, converter, + N, filtered_strings.size(), casechangeaction_); + } else { + // Nothing to filter. Copy input to output and change case if needed + status = CopyCaseAction(input_data, input_data + C, ctx, loc, converter, N, C, casechangeaction_); + } + } + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h new file mode 100644 index 0000000000000..2de272cadcc24 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +#include +#include + +namespace onnxruntime { +namespace contrib { + +class StringNormalizer : public OpKernel { + public: + enum CaseAction { + NONE = 0, + LOWER = 1, + UPPER = 2, + }; + + explicit StringNormalizer(const OpKernelInfo& info); + ~StringNormalizer() = default; + + Status Compute(OpKernelContext* ctx) const override; + + private: + CaseAction casechangeaction_; + bool iscasesensitive_; + std::vector stopwords_; + std::string locale_; // needed for upper/lowercasing actions and case insensitive compare +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc new file mode 100644 index 0000000000000..6dc985054cbc1 --- /dev/null +++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace str_normalizer_test { +constexpr const char* domain = onnxruntime::kMSDomain; +const int opset_ver = 1; + +void InitTestAttr(OpTester& test, const std::string& casechangeaction, + bool iscasesensitive, + const std::vector& stopwords, + const std::string& locale) { + test.AddAttribute("casechangeaction", casechangeaction); + test.AddAttribute("iscasesensitive", int64_t{iscasesensitive}); + if (!stopwords.empty()) { + test.AddAttribute("stopwords", stopwords); + } + test.AddAttribute("locale", locale); +} +} // namespace str_normalizer_test + +using namespace str_normalizer_test; + +TEST(ContribOpTest, StringNormalizerTest) { + // Test wrong 2 dimensions + // - casesensitive approach + // - no stopwords. + // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, "en_US"); + std::vector dims{2, 2}; + std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wendsday"), std::string("thursday")}; + test.AddInput("T", dims, input); + std::vector output(input); // do the same for now + test.AddOutput("Y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + // - casesensitive approach + // - no stopwords. + // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, "en_US"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + std::vector output(input); // do the same for now + test.AddOutput("Y", dims, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {"monday"}, "en_US"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - LOWER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", true, {"monday"}, "en_US"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - UPPER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("TUESDAY"), + std::string("WEDNESDAY"), std::string("THURSDAY")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // Empty output case + // - casesensitive approach + // - filter out monday + // - UPPER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + std::vector dims{2}; + std::vector input = {std::string("monday"), + std::string("monday")}; + test.AddInput("T", dims, input); + + std::vector output; + test.AddOutput("Y", {1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // Empty output case + // - casesensitive approach + // - filter out monday + // - UPPER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + std::vector dims{1, 2}; + std::vector input = {std::string("monday"), + std::string("monday")}; + test.AddInput("T", dims, input); + + std::vector output; + test.AddOutput("Y", {1, 1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +} // namespace test +} // namespace onnxruntime From 2de4a726a3938e2acf98d796a0ee01a988c0561e Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 30 Nov 2018 11:05:38 -0800 Subject: [PATCH 02/13] Add mixed language tests, test case insentive path. --- onnxruntime/contrib_ops/contrib_ops.cc | 2 +- .../contrib_ops/cpu/string_normalizer.cc | 55 +++++++++---- .../contrib_ops/string_normalizer_test.cc | 79 +++++++++++++++++-- 3 files changed, 115 insertions(+), 21 deletions(-) diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index 8f91b09a923d9..0ecd10e33e062 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -479,7 +479,7 @@ The bounding box coordinates corresponding to the selected indices can then be o "locale", "Platform dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", AttributeProto::STRING, - "en_US") + OPTIONAL) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index a2c1b77d3ad22..8759ee34c9702 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -26,8 +26,8 @@ namespace string_normalizer { const std::string conv_error("Conversion Error"); const std::wstring wconv_error(L"Conversion Error"); // performs tolower/toupper in-place -void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction, - std::wstring& wstr) { +inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction, + std::wstring& wstr) { if (caseaction == StringNormalizer::LOWER) { std::transform(wstr.begin(), wstr.end(), wstr.begin(), [&loc](wchar_t ch) { return std::tolower(ch, loc); }); @@ -52,7 +52,9 @@ Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, if (C == 0) { output_dims.push_back(1); TensorShape output_shape(output_dims); - ctx->Output(0, output_shape); + auto output_ten = ctx->Output(0, output_shape); + auto output_default = output_ten->template MutableData(); + new (output_default) std::string(); return Status::OK(); } @@ -104,9 +106,9 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { iscasesensitive_ = iscasesensitive != 0; info.GetAttrs("stopwords", stopwords_); - // Default is specified in the schema - status = info.GetAttr("locale", &locale_); - ONNXRUNTIME_ENFORCE(status.IsOK(), "Failed to get locale"); + ONNXRUNTIME_ENFORCE(status.IsOK(), "Failed to get stopwords"); + + locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); } Status StringNormalizer::Compute(OpKernelContext* ctx) const { @@ -148,8 +150,17 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { std::hash, std::equal_to> swords; - std::transform(stopwords_.begin(), stopwords_.end(), std::inserter(swords, swords.end()), - [](const std::string& s) { return std::cref(s); }); + for (const auto& s : stopwords_) { + if (s.empty()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Empty stopwords are invalid"); + } + auto p = swords.insert(std::cref(s)); + if (!p.second) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Duplicate stopwords not allowed"); + } + } std::vector filtered_strings; filtered_strings.reserve(C); @@ -170,13 +181,23 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { // Perform case-insensitive comparison. Convert to lowercase for NONE, LOWER and UPPER otherwise. const CaseAction ca = (casechangeaction_ == UPPER) ? UPPER : LOWER; std::unordered_set swords; - std::transform(stopwords_.begin(), stopwords_.end(), std::inserter(swords, swords.end()), - [&loc, &converter, ca](const std::string& s) { - std::wstring wstr = converter.from_bytes(s); - ChangeCase(loc, ca, wstr); - return wstr; - }); - + for (const auto& s : stopwords_) { + if (s.empty()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Empty stopwords are invalid"); + } + std::wstring wstr = converter.from_bytes(s); + if (wstr == wconv_error) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Stopword contains invalid utf8 chars at: " + s); + } + ChangeCase(loc, ca, wstr); + auto p = swords.insert(wstr); + if (!p.second) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Duplicate stopwords not allowed"); + } + } // Filter input. We choose to undergo conversion twice (if needed) // as oppose to preserve lower/uppercased strings to favor lower memory // consumption. @@ -185,6 +206,10 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { for (size_t input_idx = 0; input_idx < C; ++input_idx) { const std::string& s = *(input_data + input_idx); std::wstring wstr = converter.from_bytes(s); + if (wstr == wconv_error) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input contains invalid utf8 chars at: " + s); + } ChangeCase(loc, ca, wstr); if (0 == swords.count(wstr)) { filtered_strings.push_back(std::cref(s)); diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc index 6dc985054cbc1..b1dea337200f2 100644 --- a/onnxruntime/test/contrib_ops/string_normalizer_test.cc +++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include @@ -21,7 +21,9 @@ void InitTestAttr(OpTester& test, const std::string& casechangeaction, if (!stopwords.empty()) { test.AddAttribute("stopwords", stopwords); } - test.AddAttribute("locale", locale); + if (!locale.empty()) { + test.AddAttribute("locale", locale); + } } } // namespace str_normalizer_test @@ -105,6 +107,73 @@ TEST(ContribOpTest, StringNormalizerTest) { test.AddOutput("Y", {3}, output); test.Run(OpTester::ExpectResult::kExpectSuccess); } + // - case-SENSETIVE approach en_US locale + // - we test the behavior of a mix of english, french, german, russian and chinese + // with en_US locale + // - filter out monday + // - UPPER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {u8"monday"}, "en_US"); + std::vector dims{7}; + std::vector input = {std::string(u8"monday"), + std::string(u8"tuesday"), + std::string(u8"Besançon"), + std::string(u8"École élémentaire"), + std::string(u8"Понедельник"), + std::string(u8"mit freundlichen grüßen"), + std::string(u8"中文")}; + test.AddInput("T", dims, input); + + // en_US results (default) + std::vector output = {std::string(u8"TUESDAY"), + // It does upper case cecedille, accented E + // and german umlaut but fails + // with german eszett + std::string(u8"BESANÇON"), + std::string(u8"ÉCOLE ÉLÉMENTAIRE"), + // No issues with Cyrllic + std::string(u8"ПОНЕДЕЛЬНИК"), + std::string(u8"MIT FREUNDLICHEN GRÜßEN"), + // Chinese do not have cases + std::string(u8"中文")}; + test.AddOutput("Y", {6}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - case-INSENSETIVE approach en_US locale + // - we test the behavior of a mix of english, french, german, russian and chinese + // with en_US locale + // - filter out monday + // - UPPER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", false, {u8"monday"}, "en_US"); + std::vector dims{7}; + std::vector input = {std::string(u8"monday"), + std::string(u8"tuesday"), + std::string(u8"Besançon"), + std::string(u8"École élémentaire"), + std::string(u8"Понедельник"), + std::string(u8"mit freundlichen grüßen"), + std::string(u8"中文")}; + test.AddInput("T", dims, input); + + // en_US results (default) + std::vector output = {std::string(u8"TUESDAY"), + // It does upper case cecedille, accented E + // and german umlaut but fails + // with german eszett + std::string(u8"BESANÇON"), + std::string(u8"ÉCOLE ÉLÉMENTAIRE"), + // No issues with Cyrllic + std::string(u8"ПОНЕДЕЛЬНИК"), + std::string(u8"MIT FREUNDLICHEN GRÜßEN"), + // Chinese do not have cases + std::string(u8"中文")}; + test.AddOutput("Y", {6}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // Empty output case // - casesensitive approach // - filter out monday @@ -117,7 +186,7 @@ TEST(ContribOpTest, StringNormalizerTest) { std::string("monday")}; test.AddInput("T", dims, input); - std::vector output; + std::vector output{""}; // One empty string test.AddOutput("Y", {1}, output); test.Run(OpTester::ExpectResult::kExpectSuccess); } @@ -127,13 +196,13 @@ TEST(ContribOpTest, StringNormalizerTest) { // - UPPER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + InitTestAttr(test, "UPPER", true, {"monday"}, ""); std::vector dims{1, 2}; std::vector input = {std::string("monday"), std::string("monday")}; test.AddInput("T", dims, input); - std::vector output; + std::vector output{""}; // One empty string test.AddOutput("Y", {1, 1}, output); test.Run(OpTester::ExpectResult::kExpectSuccess); } From 0741a0b46cac888e7e465d654108c359ffe97598 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 30 Nov 2018 11:40:50 -0800 Subject: [PATCH 03/13] Favor perf over memory, store converted strings. Add std::move(). --- .../contrib_ops/cpu/string_normalizer.cc | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 8759ee34c9702..c427fa7b5df34 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -66,19 +66,19 @@ Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, size_t output_idx = 0; while (first != end) { - const std::string& s = *first; + auto& s = *first; if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) { std::wstring wstr = converter.from_bytes(s); if (wstr == wconv_error) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Input contains invalid utf8 chars at: " + s); + "Input contains invalid utf8 chars at: " + static_cast(s)); } // In place transform ChangeCase(loc, caseaction, wstr); new (output_data + output_idx) std::string(converter.to_bytes(wstr)); } else { - // Simple copy - new (output_data + output_idx) std::string(s); + // Simple copy or move if the iterator points to a non-const string + new (output_data + output_idx) std::string(std::move(s)); } ++output_idx; ++first; @@ -198,11 +198,13 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { "Duplicate stopwords not allowed"); } } - // Filter input. We choose to undergo conversion twice (if needed) - // as oppose to preserve lower/uppercased strings to favor lower memory - // consumption. - std::vector filtered_strings; - filtered_strings.reserve(C); + // Filter input. When no case action is required + // we simply store original string references. + // Otherwise, we store converted strings. + std::vector filtered_orignal_strings; + std::vector filtered_cased_strings; + filtered_orignal_strings.reserve(C); + filtered_cased_strings.reserve(C); for (size_t input_idx = 0; input_idx < C; ++input_idx) { const std::string& s = *(input_data + input_idx); std::wstring wstr = converter.from_bytes(s); @@ -212,11 +214,20 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } ChangeCase(loc, ca, wstr); if (0 == swords.count(wstr)) { - filtered_strings.push_back(std::cref(s)); + if (casechangeaction_ == NONE) { + filtered_orignal_strings.push_back(std::cref(s)); + } else { + filtered_cased_strings.push_back(converter.to_bytes(wstr)); + } } } - status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, loc, converter, - N, filtered_strings.size(), casechangeaction_); + if (casechangeaction_ == NONE) { + status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, loc, converter, + N, filtered_orignal_strings.size(), NONE); + } else { + status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, loc, converter, + N, filtered_cased_strings.size(), NONE); + } } else { // Nothing to filter. Copy input to output and change case if needed status = CopyCaseAction(input_data, input_data + C, ctx, loc, converter, N, C, casechangeaction_); From 8fa693332e47ecff454281cf3c8e2d5fb8f60710 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 30 Nov 2018 15:04:21 -0800 Subject: [PATCH 04/13] Do not ignore return value from GetAttrs. --- onnxruntime/contrib_ops/cpu/string_normalizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index c427fa7b5df34..4364edead1a95 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -105,7 +105,7 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute iscasesensitive is not set"); iscasesensitive_ = iscasesensitive != 0; - info.GetAttrs("stopwords", stopwords_); + status = info.GetAttrs("stopwords", stopwords_); ONNXRUNTIME_ENFORCE(status.IsOK(), "Failed to get stopwords"); locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); From 340df4c9f083a47442ae2224decd76f0f363fc1e Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 30 Nov 2018 16:23:57 -0800 Subject: [PATCH 05/13] Make stopwords optional. --- onnxruntime/contrib_ops/cpu/string_normalizer.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 4364edead1a95..576715af1a1d2 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -105,9 +105,7 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute iscasesensitive is not set"); iscasesensitive_ = iscasesensitive != 0; - status = info.GetAttrs("stopwords", stopwords_); - ONNXRUNTIME_ENFORCE(status.IsOK(), "Failed to get stopwords"); - + stopwords_ = info.GetAttrsOrDefault("stopwords"); locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); } From 7da88f850a7bed0c63cd7780ad6fa4f37b16047d Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 3 Dec 2018 11:43:13 -0800 Subject: [PATCH 06/13] Address review comments. --- onnxruntime/contrib_ops/contrib_ops.cc | 6 +- .../contrib_ops/cpu/string_normalizer.cc | 118 +++++++++--------- .../contrib_ops/cpu/string_normalizer.h | 12 +- .../contrib_ops/string_normalizer_test.cc | 4 +- 4 files changed, 70 insertions(+), 70 deletions(-) diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index 0ecd10e33e062..df9ca7e566a31 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -12,8 +12,8 @@ namespace onnxruntime { namespace contrib { using ::ONNX_NAMESPACE::AttributeProto; -using ::ONNX_NAMESPACE::OPTIONAL; using ::ONNX_NAMESPACE::OpSchema; +using ::ONNX_NAMESPACE::OPTIONAL; void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp) @@ -467,7 +467,7 @@ The bounding box coordinates corresponding to the selected indices can then be o "string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"", AttributeProto::STRING) .Attr( - "iscasesensitive", + "is_case_sensitive", "Boolean. Whether the identification of stop words in X is case-sensitive.", AttributeProto::INT) .Attr( @@ -477,7 +477,7 @@ The bounding box coordinates corresponding to the selected indices can then be o OPTIONAL) .Attr( "locale", - "Platform dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", + "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", AttributeProto::STRING, OPTIONAL) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 576715af1a1d2..a0a974457a6fd 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -28,6 +28,7 @@ const std::wstring wconv_error(L"Conversion Error"); // performs tolower/toupper in-place inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction, std::wstring& wstr) { + assert(caseaction != StringNormalizer::NONE); if (caseaction == StringNormalizer::LOWER) { std::transform(wstr.begin(), wstr.end(), wstr.begin(), [&loc](wchar_t ch) { return std::tolower(ch, loc); }); @@ -77,6 +78,7 @@ Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, ChangeCase(loc, caseaction, wstr); new (output_data + output_idx) std::string(converter.to_bytes(wstr)); } else { + assert(caseaction == StringNormalizer::NONE); // Simple copy or move if the iterator points to a non-const string new (output_data + output_idx) std::string(std::move(s)); } @@ -87,9 +89,19 @@ Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, } } // namespace string_normalizer -StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { +using namespace string_normalizer; + +StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), + is_case_sensitive_(true), + casechangeaction_(NONE), + compare_caseaction_(NONE) { + int64_t iscasesensitive = 0; + Status status = info.GetAttr("is_case_sensitive", &iscasesensitive); + ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set"); + is_case_sensitive_ = iscasesensitive != 0; + std::string casechangeaction; - auto status = info.GetAttr("casechangeaction", &casechangeaction); + status = info.GetAttr("casechangeaction", &casechangeaction); ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute caseaction is not set"); if (casechangeaction == "LOWER") { casechangeaction_ = LOWER; @@ -100,13 +112,30 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { } else { ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value"); } - int64_t iscasesensitive = 0; - status = info.GetAttr("iscasesensitive", &iscasesensitive); - ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute iscasesensitive is not set"); - iscasesensitive_ = iscasesensitive != 0; - stopwords_ = info.GetAttrsOrDefault("stopwords"); - locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); + if (!is_case_sensitive_) { + // Convert stop words to a case which can help us preserve the case of filtered strings + compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER; + } + + std::string loc = info.GetAttrOrDefault("locale", std::string("en_US")); + locale_ = std::locale(loc); + std::wstring_convert> converter(conv_error, wconv_error); + + std::vector swords = info.GetAttrsOrDefault("stopwords"); + for (const auto& sw : swords) { + ONNXRUNTIME_ENFORCE(!sw.empty(), "Empty stopwords not allowed"); + if (is_case_sensitive_) { + auto p = stopwords_.insert(sw); + ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed"); + } else { + std::wstring wstr = converter.from_bytes(sw); + ONNXRUNTIME_ENFORCE(wstr != wconv_error, "Stopword contains invalid utf8 chars"); + ChangeCase(locale_, compare_caseaction_, wstr); + auto p = wstopwords_.insert(wstr); + ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed"); + } + } } Status StringNormalizer::Compute(OpKernelContext* ctx) const { @@ -136,66 +165,30 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } Status status; - std::locale loc(locale_); std::wstring_convert> converter(conv_error, wconv_error); auto const input_data = X->template Data(); - using StrRef = std::reference_wrapper; - if (iscasesensitive_) { + if (is_case_sensitive_) { if (!stopwords_.empty()) { - // Create a filter and create filtered output - std::unordered_set, - std::equal_to> - swords; - for (const auto& s : stopwords_) { - if (s.empty()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Empty stopwords are invalid"); - } - auto p = swords.insert(std::cref(s)); - if (!p.second) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Duplicate stopwords not allowed"); - } - } - std::vector filtered_strings; filtered_strings.reserve(C); - for (size_t input_idx = 0; input_idx < C; ++input_idx) { - const std::string& s = *(input_data + input_idx); - if (0 == swords.count(s)) { + auto first = input_data; + auto const last = input_data + C; + while (first != last) { + const std::string& s = *first; + if (0 == stopwords_.count(s)) { filtered_strings.push_back(std::cref(s)); } + ++first; } - status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, loc, converter, + status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale_, converter, N, filtered_strings.size(), casechangeaction_); } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, loc, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale_, converter, N, C, casechangeaction_); } } else { - if (!stopwords_.empty()) { - // Perform case-insensitive comparison. Convert to lowercase for NONE, LOWER and UPPER otherwise. - const CaseAction ca = (casechangeaction_ == UPPER) ? UPPER : LOWER; - std::unordered_set swords; - for (const auto& s : stopwords_) { - if (s.empty()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Empty stopwords are invalid"); - } - std::wstring wstr = converter.from_bytes(s); - if (wstr == wconv_error) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Stopword contains invalid utf8 chars at: " + s); - } - ChangeCase(loc, ca, wstr); - auto p = swords.insert(wstr); - if (!p.second) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Duplicate stopwords not allowed"); - } - } + if (!wstopwords_.empty()) { // Filter input. When no case action is required // we simply store original string references. // Otherwise, we store converted strings. @@ -203,32 +196,35 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { std::vector filtered_cased_strings; filtered_orignal_strings.reserve(C); filtered_cased_strings.reserve(C); - for (size_t input_idx = 0; input_idx < C; ++input_idx) { - const std::string& s = *(input_data + input_idx); + auto first = input_data; + auto const last = input_data + C; + while (first != last) { + const std::string& s = *first; std::wstring wstr = converter.from_bytes(s); if (wstr == wconv_error) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input contains invalid utf8 chars at: " + s); } - ChangeCase(loc, ca, wstr); - if (0 == swords.count(wstr)) { + ChangeCase(locale_, compare_caseaction_, wstr); + if (0 == wstopwords_.count(wstr)) { if (casechangeaction_ == NONE) { filtered_orignal_strings.push_back(std::cref(s)); } else { filtered_cased_strings.push_back(converter.to_bytes(wstr)); } } + ++first; } if (casechangeaction_ == NONE) { - status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, loc, converter, + status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale_, converter, N, filtered_orignal_strings.size(), NONE); } else { - status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, loc, converter, + status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale_, converter, N, filtered_cased_strings.size(), NONE); } } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, loc, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale_, converter, N, C, casechangeaction_); } } return status; diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h index 2de272cadcc24..f1d9207240b1d 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.h +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h @@ -5,8 +5,9 @@ #include "core/framework/op_kernel.h" +#include #include -#include +#include namespace onnxruntime { namespace contrib { @@ -25,10 +26,13 @@ class StringNormalizer : public OpKernel { Status Compute(OpKernelContext* ctx) const override; private: + bool is_case_sensitive_; CaseAction casechangeaction_; - bool iscasesensitive_; - std::vector stopwords_; - std::string locale_; // needed for upper/lowercasing actions and case insensitive compare + CaseAction compare_caseaction_; // used for case-insensitive compare + std::locale locale_; // needed for upper/lowercasing actions and case insensitive compare + // Either if these are populated but not both + std::unordered_set stopwords_; + std::unordered_set wstopwords_; }; } // namespace contrib diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc index b1dea337200f2..295e32be45888 100644 --- a/onnxruntime/test/contrib_ops/string_normalizer_test.cc +++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc @@ -17,7 +17,7 @@ void InitTestAttr(OpTester& test, const std::string& casechangeaction, const std::vector& stopwords, const std::string& locale) { test.AddAttribute("casechangeaction", casechangeaction); - test.AddAttribute("iscasesensitive", int64_t{iscasesensitive}); + test.AddAttribute("is_case_sensitive", int64_t{iscasesensitive}); if (!stopwords.empty()) { test.AddAttribute("stopwords", stopwords); } @@ -38,7 +38,7 @@ TEST(ContribOpTest, StringNormalizerTest) { OpTester test("StringNormalizer", opset_ver, domain); InitTestAttr(test, "NONE", true, {}, "en_US"); std::vector dims{2, 2}; - std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wendsday"), std::string("thursday")}; + std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; test.AddInput("T", dims, input); std::vector output(input); // do the same for now test.AddOutput("Y", dims, output); From 584d06ffd79b2e863a26d4d028cd91783ee1abd3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 3 Dec 2018 14:42:43 -0800 Subject: [PATCH 07/13] Address typos. --- onnxruntime/contrib_ops/contrib_ops.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index df9ca7e566a31..ad52918ab7705 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -484,8 +484,7 @@ The bounding box coordinates corresponding to the selected indices can then be o auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); }) - .SetDoc(R"DOC([optional] Step1: Remove elements in X if they matches any of stop words so that output tensor may not contain any stop word. This operator only accepts [C]- and [1, C]-tensor. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C]. -[optional] Step2: Lower all characters (if action is LOWER) in X or capitalize them (when action is UPPER))DOC"); + .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); } class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); From 6e5c1ae0a3987dd91ca33e8db5380aed3d8ddb52 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 3 Dec 2018 16:15:18 -0800 Subject: [PATCH 08/13] Address test failure. --- onnxruntime/test/ir/graph_test.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 445d044ef1f40..5947c61aaca31 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -557,10 +557,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) // Validate that an unused initializer doesn't break graph loading/resolution // and is removed as expected. TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) { - OPERATOR_SCHEMA(Identity_Fake) - .SetDoc("Identity.") - .Input(0, "input_1", "docstr for input_1.", "tensor(int32)") - .Output(0, "output_1", "docstr for output_1.", "tensor(int32)"); + ASSERT_TRUE(kSchemasRegistered); Model model("UnusedInitializerIsIgnored"); auto& graph = model.MainGraph(); From 34da9ce8bca2607e26196e3c844bd6511b2d2c44 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 3 Dec 2018 16:54:10 -0800 Subject: [PATCH 09/13] Create a locale on the fly. Default locale does not seem to create well. --- .../contrib_ops/cpu/string_normalizer.cc | 19 ++++++++++--------- .../contrib_ops/cpu/string_normalizer.h | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index a0a974457a6fd..1aefb12cf363d 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -118,8 +118,8 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER; } - std::string loc = info.GetAttrOrDefault("locale", std::string("en_US")); - locale_ = std::locale(loc); + locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); + std::locale locale(locale_); std::wstring_convert> converter(conv_error, wconv_error); std::vector swords = info.GetAttrsOrDefault("stopwords"); @@ -131,7 +131,7 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), } else { std::wstring wstr = converter.from_bytes(sw); ONNXRUNTIME_ENFORCE(wstr != wconv_error, "Stopword contains invalid utf8 chars"); - ChangeCase(locale_, compare_caseaction_, wstr); + ChangeCase(locale, compare_caseaction_, wstr); auto p = wstopwords_.insert(wstr); ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed"); } @@ -165,6 +165,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } Status status; + std::locale locale(locale_); std::wstring_convert> converter(conv_error, wconv_error); auto const input_data = X->template Data(); using StrRef = std::reference_wrapper; @@ -181,11 +182,11 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } ++first; } - status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale_, converter, + status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter, N, filtered_strings.size(), casechangeaction_); } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, locale_, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); } } else { if (!wstopwords_.empty()) { @@ -205,7 +206,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input contains invalid utf8 chars at: " + s); } - ChangeCase(locale_, compare_caseaction_, wstr); + ChangeCase(locale, compare_caseaction_, wstr); if (0 == wstopwords_.count(wstr)) { if (casechangeaction_ == NONE) { filtered_orignal_strings.push_back(std::cref(s)); @@ -216,15 +217,15 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { ++first; } if (casechangeaction_ == NONE) { - status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale_, converter, + status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale, converter, N, filtered_orignal_strings.size(), NONE); } else { - status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale_, converter, + status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale, converter, N, filtered_cased_strings.size(), NONE); } } else { // Nothing to filter. Copy input to output and change case if needed - status = CopyCaseAction(input_data, input_data + C, ctx, locale_, converter, N, C, casechangeaction_); + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); } } return status; diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h index f1d9207240b1d..73fe4dc878557 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.h +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h @@ -29,7 +29,7 @@ class StringNormalizer : public OpKernel { bool is_case_sensitive_; CaseAction casechangeaction_; CaseAction compare_caseaction_; // used for case-insensitive compare - std::locale locale_; // needed for upper/lowercasing actions and case insensitive compare + std::string locale_; // Either if these are populated but not both std::unordered_set stopwords_; std::unordered_set wstopwords_; From def825e5e1d0994b1a76fe5802763ebe16443c64 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 3 Dec 2018 17:46:03 -0800 Subject: [PATCH 10/13] Try default locale as en_US.UTF-8 --- onnxruntime/contrib_ops/cpu/string_normalizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 1aefb12cf363d..39b264e3a481d 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -118,7 +118,7 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER; } - locale_ = info.GetAttrOrDefault("locale", std::string("en_US")); + locale_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8")); std::locale locale(locale_); std::wstring_convert> converter(conv_error, wconv_error); From c549d83069d7017cbb3e0f9b7c9e1a6b0afc279c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 4 Dec 2018 10:31:20 -0800 Subject: [PATCH 11/13] Add CI language-pack-en to make default locale available. Catch and translate locale creation exception to make the message meaningful. --- onnxruntime/contrib_ops/cpu/string_normalizer.cc | 16 +++++++++++++--- onnxruntime/contrib_ops/cpu/string_normalizer.h | 2 +- .../linux/docker/scripts/install_ubuntu.sh | 1 + .../ci_build/github/linux/ubuntu16.04/install.sh | 1 + 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 39b264e3a481d..43e3b250dc44c 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -87,6 +87,16 @@ Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, } return Status::OK(); } + +inline std::locale GetLocale(const std::string& locale_name) { + try { + std::locale result(locale_name); + return result; + } catch (const std::runtime_error& e) { + ONNXRUNTIME_THROW("Failed to construct locale with name: ", + locale_name, e.what(), " Please, install necessary language-pack-XX"); + } +} } // namespace string_normalizer using namespace string_normalizer; @@ -118,8 +128,8 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER; } - locale_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8")); - std::locale locale(locale_); + locale_name_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8")); + std::locale locale = GetLocale(locale_name_); std::wstring_convert> converter(conv_error, wconv_error); std::vector swords = info.GetAttrsOrDefault("stopwords"); @@ -165,7 +175,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { } Status status; - std::locale locale(locale_); + std::locale locale = GetLocale(locale_name_); std::wstring_convert> converter(conv_error, wconv_error); auto const input_data = X->template Data(); using StrRef = std::reference_wrapper; diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h index 73fe4dc878557..8bc865400f6d4 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.h +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h @@ -29,7 +29,7 @@ class StringNormalizer : public OpKernel { bool is_case_sensitive_; CaseAction casechangeaction_; CaseAction compare_caseaction_; // used for case-insensitive compare - std::string locale_; + std::string locale_name_; // Either if these are populated but not both std::unordered_set stopwords_; std::unordered_set wstopwords_; diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index 163342d157f49..eb0aed9d5e859 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -25,6 +25,7 @@ apt-get update && apt-get install -y --no-install-recommends \ sudo \ gfortran \ python3-dev \ + language-pack-en \ libopenblas-dev \ liblttng-ust0 \ libcurl3 \ diff --git a/tools/ci_build/github/linux/ubuntu16.04/install.sh b/tools/ci_build/github/linux/ubuntu16.04/install.sh index 7c35cc3c75657..cc49f269173e0 100755 --- a/tools/ci_build/github/linux/ubuntu16.04/install.sh +++ b/tools/ci_build/github/linux/ubuntu16.04/install.sh @@ -16,6 +16,7 @@ apt-get update && apt-get install -y --no-install-recommends \ sudo \ gfortran \ python3-dev \ + language-pack-en \ libopenblas-dev \ liblttng-ust0 \ libcurl3 \ From 50ca1c5e28cc62369a7991c947c5fbe72fd4bdb9 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 4 Dec 2018 12:08:01 -0800 Subject: [PATCH 12/13] Make sure tests use en_US.UTF-8. Adjust exception message. --- onnxruntime/contrib_ops/cpu/string_normalizer.cc | 4 ++-- .../test/contrib_ops/string_normalizer_test.cc | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc index 43e3b250dc44c..f367302095d94 100644 --- a/onnxruntime/contrib_ops/cpu/string_normalizer.cc +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -93,8 +93,8 @@ inline std::locale GetLocale(const std::string& locale_name) { std::locale result(locale_name); return result; } catch (const std::runtime_error& e) { - ONNXRUNTIME_THROW("Failed to construct locale with name: ", - locale_name, e.what(), " Please, install necessary language-pack-XX"); + ONNXRUNTIME_THROW("Failed to construct locale with name:", + locale_name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales"); } } } // namespace string_normalizer diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc index 295e32be45888..5cf060775adc2 100644 --- a/onnxruntime/test/contrib_ops/string_normalizer_test.cc +++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc @@ -36,7 +36,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - No change case action { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "NONE", true, {}, "en_US"); + InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8"); std::vector dims{2, 2}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; test.AddInput("T", dims, input); @@ -50,7 +50,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - No change case action { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "NONE", true, {}, "en_US"); + InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8"); std::vector dims{4}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; @@ -64,7 +64,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - No change case action { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "NONE", true, {"monday"}, "en_US"); + InitTestAttr(test, "NONE", true, {"monday"}, "en_US.UTF-8"); std::vector dims{4}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; @@ -80,7 +80,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - LOWER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "LOWER", true, {"monday"}, "en_US"); + InitTestAttr(test, "LOWER", true, {"monday"}, "en_US.UTF-8"); std::vector dims{4}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; @@ -96,7 +96,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - UPPER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8"); std::vector dims{4}; std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; @@ -114,7 +114,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - UPPER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "UPPER", true, {u8"monday"}, "en_US"); + InitTestAttr(test, "UPPER", true, {u8"monday"}, "en_US.UTF-8"); std::vector dims{7}; std::vector input = {std::string(u8"monday"), std::string(u8"tuesday"), @@ -147,7 +147,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - UPPER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "UPPER", false, {u8"monday"}, "en_US"); + InitTestAttr(test, "UPPER", false, {u8"monday"}, "en_US.UTF-8"); std::vector dims{7}; std::vector input = {std::string(u8"monday"), std::string(u8"tuesday"), @@ -180,7 +180,7 @@ TEST(ContribOpTest, StringNormalizerTest) { // - UPPER should produce the same output as they are all lower. { OpTester test("StringNormalizer", opset_ver, domain); - InitTestAttr(test, "UPPER", true, {"monday"}, "en_US"); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8"); std::vector dims{2}; std::vector input = {std::string("monday"), std::string("monday")}; From b41cf7e5eba9010bd6ee0042f45fb93aacd688a3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 4 Dec 2018 12:10:56 -0800 Subject: [PATCH 13/13] Make sure locales are configured on Ubuntu. --- tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh | 3 +++ tools/ci_build/github/linux/ubuntu16.04/install.sh | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index eb0aed9d5e859..3705834ee25d1 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -39,6 +39,9 @@ apt-get update && apt-get install -y --no-install-recommends \ rsync libunwind8 libpng16-dev \ python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest +locale-gen en_US.UTF-8 +update-locale LANG=en_US.UTF-8 + if [ $PYTHON_VER != "3.5" ]; then apt-get install -y --no-install-recommends \ python${PYTHON_VER} \ diff --git a/tools/ci_build/github/linux/ubuntu16.04/install.sh b/tools/ci_build/github/linux/ubuntu16.04/install.sh index cc49f269173e0..ca4c289f41cc5 100755 --- a/tools/ci_build/github/linux/ubuntu16.04/install.sh +++ b/tools/ci_build/github/linux/ubuntu16.04/install.sh @@ -29,6 +29,9 @@ apt-get update && apt-get install -y --no-install-recommends \ rsync libunwind8 \ python3-setuptools python3-numpy python3-wheel python python3-pip +locale-gen en_US.UTF-8 +update-locale LANG=en_US.UTF-8 + rm -rf /var/lib/apt/lists/* aria2c -q -d /tmp https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip